Diffstat (file mode, path, lines changed):
-rw-r--r--  RELEASE.md  2
-rw-r--r--  configure.py  37
-rw-r--r--  tensorflow/BUILD  2
-rw-r--r--  tensorflow/c/eager/tape.h  7
-rw-r--r--  tensorflow/c/python_api.cc  2
-rw-r--r--  tensorflow/cc/saved_model/BUILD  30
-rw-r--r--  tensorflow/cc/saved_model/loader.cc  70
-rw-r--r--  tensorflow/cc/saved_model/reader.cc  88
-rw-r--r--  tensorflow/cc/saved_model/reader.h  39
-rw-r--r--  tensorflow/cc/saved_model/reader_test.cc  108
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc  9
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc  7
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc  5
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc  18
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc  4
-rw-r--r--  tensorflow/compiler/jit/xla_cpu_device.cc  1
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc  70
-rw-r--r--  tensorflow/compiler/jit/xla_device.h  22
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc  305
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h  24
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h  3
-rw-r--r--  tensorflow/compiler/jit/xla_gpu_device.cc  1
-rw-r--r--  tensorflow/compiler/jit/xla_interpreter_device.cc  1
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc  34
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h  9
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.cc  30
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h  26
-rw-r--r--  tensorflow/compiler/tests/BUILD  82
-rw-r--r--  tensorflow/compiler/tests/adagrad_da_test.py  165
-rw-r--r--  tensorflow/compiler/tests/adamax_test.py  139
-rw-r--r--  tensorflow/compiler/tests/addsign_test.py  142
-rw-r--r--  tensorflow/compiler/tests/conv2d_test.py  3
-rw-r--r--  tensorflow/compiler/tests/eager_test.py  48
-rw-r--r--  tensorflow/compiler/tests/powersign_test.py  142
-rw-r--r--  tensorflow/compiler/tests/qr_op_test.py  112
-rw-r--r--  tensorflow/compiler/tests/random_ops_test.py  4
-rw-r--r--  tensorflow/compiler/tests/rmsprop_test.py  117
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD  12
-rw-r--r--  tensorflow/compiler/tf2xla/graph_compiler.cc  3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD  8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/bcast_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/elu_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc  8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pooling_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/qr_op.cc  47
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/random_ops.cc  159
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduction_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/relu_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reshape_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reverse_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/scan_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/sequence_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/split_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/stack_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/topk_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/training_ops.cc  333
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/unpack_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/variable_ops.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/while_op.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc  63
-rw-r--r--  tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h  49
-rw-r--r--  tensorflow/compiler/tf2xla/lib/BUILD  33
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.cc  387
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.h  40
-rw-r--r--  tensorflow/compiler/tf2xla/lib/scatter.cc  4
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc  32
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util_test.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/lib/while_loop.cc  9
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util.h  2
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util_test.cc  5
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_test.cc  5
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc  74
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.cc  2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_cpu_backend.cc  4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_gpu_backend.cc  13
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc  4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc  71
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h  42
-rw-r--r--  tensorflow/compiler/xla/BUILD  45
-rw-r--r--  tensorflow/compiler/xla/client/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/client/client.cc  2
-rw-r--r--  tensorflow/compiler/xla/client/client.h  2
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD  3
-rw-r--r--  tensorflow/compiler/xla/client/lib/constants.cc  8
-rw-r--r--  tensorflow/compiler/xla/client/lib/math_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.cc  8
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.h  4
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.cc  4
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/BUILD  3
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc  162
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h  159
-rw-r--r--  tensorflow/compiler/xla/literal.cc  1967
-rw-r--r--  tensorflow/compiler/xla/literal.h  1152
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.cc  5
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.h  2
-rw-r--r--  tensorflow/compiler/xla/literal_test.cc (renamed from tensorflow/compiler/xla/literal_util_test.cc)  540
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc  2115
-rw-r--r--  tensorflow/compiler/xla/literal_util.h  1171
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.cc  2
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.h  2
-rw-r--r--  tensorflow/compiler/xla/python/BUILD  3
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i  2
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc  5
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.h  2
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py  8
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc  5
-rw-r--r--  tensorflow/compiler/xla/reference_util_test.cc  46
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_client_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/BUILD  117
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc  15
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc  112
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc  56
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander_test.cc  31
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc  135
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation_test.cc  20
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc  67
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc  26
-rw-r--r--  tensorflow/compiler/xla/service/call_graph_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner_test.cc  18
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier_test.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h  13
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion_test.cc  103
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD  12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h  5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc  596
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h  12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/sample_harness.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/BUILD  6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc  24
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc  77
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/defuser_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h  2
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h  2
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc  103
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/executable.cc  13
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph_test.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.h  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD  43
-rw-r--r--  tensorflow/compiler/xla/service/gpu/conditional_thunk.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc  28
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/for_thunk.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc  12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc  104
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h  12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_schedule.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/infeed_manager.cc  69
-rw-r--r--  tensorflow/compiler/xla/service/gpu/infeed_manager.h  82
-rw-r--r--  tensorflow/compiler/xla/service/gpu/infeed_thunk.cc  91
-rw-r--r--  tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc  1078
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h  62
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_manager.cc  32
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_manager.h  69
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc  111
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_thunk.h  52
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc  11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/sequential_thunk.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk.h  1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_thunk.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/service/gpu/xfeed_queue.h  89
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc  66
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  58
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h  15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc  84
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils_test.cc  46
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc  97
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc  154
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce_test.cc  29
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.cc  13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.h  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_remover.cc  44
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_test.cc  38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.cc  124
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.h  65
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc  64
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc  395
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h  22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  165
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h  43
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc  78
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc  154
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h  72
-rw-r--r--  tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers_test.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc  40
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc  29
-rw-r--r--  tensorflow/compiler/xla/service/hlo_query.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc  22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h  22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc  48
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_scheduling_test.cc  40
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h  11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc  14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc  89
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier_test.cc  51
-rw-r--r--  tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc  180
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h  12
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis_test.cc  165
-rw-r--r--  tensorflow/compiler/xla/service/inliner_test.cc  28
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/BUILD  2
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc  21
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc  25
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD  39
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc  7
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc  83
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc  19
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h  7
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.cc  4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.h  12
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h  18
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc  118
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h  80
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.h  2
-rw-r--r--  tensorflow/compiler/xla/service/platform_util.cc  13
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc  15
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc  122
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.h  12
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc  85
-rw-r--r--  tensorflow/compiler/xla/service/tuple_simplifier_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/service/while_util.cc  9
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc  5
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc  11
-rw-r--r--  tensorflow/compiler/xla/shape_layout.cc  8
-rw-r--r--  tensorflow/compiler/xla/shape_layout.h  4
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc  58
-rw-r--r--  tensorflow/compiler/xla/shape_util.h  43
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc  34
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD  71
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc  89
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc  94
-rw-r--r--  tensorflow/compiler/xla/tests/bfloat16_test.cc  18
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_simple_test.cc  96
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc  79
-rw-r--r--  tensorflow/compiler/xla/tests/call_test.cc  20
-rw-r--r--  tensorflow/compiler/xla/tests/check_execution_arity_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc  15
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h  44
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/compilation_cache_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/conditional_test.cc  32
-rw-r--r--  tensorflow/compiler/xla/tests/constants_test.cc  25
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc  3
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_test.cc  68
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_variants_test.cc  14
-rw-r--r--  tensorflow/compiler/xla/tests/copy_test.cc  27
-rw-r--r--  tensorflow/compiler/xla/tests/cross_replica_sum_test.cc  16
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc  68
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc  65
-rw-r--r--  tensorflow/compiler/xla/tests/execution_profile_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/filecheck.cc  5
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc  155
-rw-r--r--  tensorflow/compiler/xla/tests/gather_operation_test.cc  198
-rw-r--r--  tensorflow/compiler/xla/tests/half_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc  10
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h  7
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.h  31
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util_test.cc  46
-rw-r--r--  tensorflow/compiler/xla/tests/llvm_compiler_test.cc  5
-rw-r--r--  tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc  27
-rw-r--r--  tensorflow/compiler/xla/tests/llvm_irgen_test_base.h  8
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_allocation_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_execute_test.cc  131
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.cc  14
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc  52
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc  24
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc  120
-rw-r--r--  tensorflow/compiler/xla/tests/pad_test.cc  51
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc  60
-rw-r--r--  tensorflow/compiler/xla/tests/prng_test.cc  4
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_hlo_test.cc  30
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_precision_test.cc  15
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc  42
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc  136
-rw-r--r--  tensorflow/compiler/xla/tests/replay_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_motion_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc  208
-rw-r--r--  tensorflow/compiler/xla/tests/reverse_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_transfer_test.cc  50
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc  25
-rw-r--r--  tensorflow/compiler/xla/tests/select_and_scatter_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc  9
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h  2
-rw-r--r--  tensorflow/compiler/xla/tests/token_hlo_test.cc  20
-rw-r--r--  tensorflow/compiler/xla/tests/transfer_manager_test.cc  80
-rw-r--r--  tensorflow/compiler/xla/tests/tuple_test.cc  106
-rw-r--r--  tensorflow/compiler/xla/tests/unary_op_test.cc  8
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc  69
-rw-r--r--  tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc  55
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.cc  2
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.h  2
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader_test.cc  2
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.cc  2
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.h  2
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer_test.cc  6
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD  6
-rw-r--r--  tensorflow/compiler/xla/tools/replay_computation.cc  2
-rw-r--r--  tensorflow/compiler/xla/tools/show_literal.cc  2
-rw-r--r--  tensorflow/compiler/xla/tools/show_text_literal.cc  2
-rw-r--r--  tensorflow/compiler/xla/util.h  11
-rw-r--r--  tensorflow/contrib/BUILD  1
-rw-r--r--  tensorflow/contrib/autograph/README.md  2
-rw-r--r--  tensorflow/contrib/autograph/__init__.py  7
-rw-r--r--  tensorflow/contrib/autograph/converters/BUILD  13
-rw-r--r--  tensorflow/contrib/autograph/converters/__init__.py  14
-rw-r--r--  tensorflow/contrib/autograph/converters/error_handlers.py  52
-rw-r--r--  tensorflow/contrib/autograph/converters/error_handlers_test.py  61
-rw-r--r--  tensorflow/contrib/autograph/converters/single_return.py  5
-rw-r--r--  tensorflow/contrib/autograph/converters/slices.py  3
-rw-r--r--  tensorflow/contrib/autograph/core/BUILD  36
-rw-r--r--  tensorflow/contrib/autograph/core/converter.py  120
-rw-r--r--  tensorflow/contrib/autograph/core/converter_testing.py  3
-rw-r--r--  tensorflow/contrib/autograph/core/errors.py  272
-rw-r--r--  tensorflow/contrib/autograph/core/errors_test.py  116
-rw-r--r--  tensorflow/contrib/autograph/examples/integration_tests/BUILD  29
-rw-r--r--  tensorflow/contrib/autograph/examples/integration_tests/keras_test.py (renamed from tensorflow/contrib/proto/python/kernel_tests/test_case.py)  30
-rw-r--r--  tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb  664
-rw-r--r--  tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb  1093
-rw-r--r--  tensorflow/contrib/autograph/impl/api.py  24
-rw-r--r--  tensorflow/contrib/autograph/impl/api_test.py  15
-rw-r--r--  tensorflow/contrib/autograph/impl/conversion.py  9
-rw-r--r--  tensorflow/contrib/autograph/impl/conversion_test.py  8
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py  2
-rw-r--r--  tensorflow/contrib/autograph/pyct/BUILD  1
-rw-r--r--  tensorflow/contrib/autograph/pyct/anno.py  91
-rw-r--r--  tensorflow/contrib/autograph/pyct/anno_test.py  23
-rw-r--r--  tensorflow/contrib/autograph/pyct/ast_util.py  175
-rw-r--r--  tensorflow/contrib/autograph/pyct/ast_util_test.py  142
-rw-r--r--  tensorflow/contrib/autograph/pyct/cfg.py  137
-rw-r--r--  tensorflow/contrib/autograph/pyct/cfg_test.py  213
-rw-r--r--  tensorflow/contrib/autograph/pyct/compiler.py  97
-rw-r--r--  tensorflow/contrib/autograph/pyct/compiler_test.py  4
-rw-r--r--  tensorflow/contrib/autograph/pyct/origin_info.py  100
-rw-r--r--  tensorflow/contrib/autograph/pyct/qual_names.py  28
-rw-r--r--  tensorflow/contrib/autograph/pyct/qual_names_test.py  9
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/BUILD  36
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/__init__.py  12
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/annos.py  10
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/liveness.py  200
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py  149
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py  273
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py  221
-rw-r--r--  tensorflow/contrib/autograph/pyct/templates.py  88
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer.py  139
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer_test.py  77
-rw-r--r--  tensorflow/contrib/batching/python/ops/batch_ops.py  6
-rw-r--r--  tensorflow/contrib/bigtable/BUILD  53
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc  28
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_lib.h  3
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc  68
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h  67
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc  107
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc  200
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc  113
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc  29
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc  55
-rw-r--r--  tensorflow/contrib/bigtable/ops/bigtable_ops.cc  18
-rw-r--r--  tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py  140
-rw-r--r--  tensorflow/contrib/bigtable/python/ops/bigtable_api.py  332
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/estimator.py  85
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py  47
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/model.py  183
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h  2
-rw-r--r--  tensorflow/contrib/boosted_trees/python/utils/losses.py  67
-rw-r--r--  tensorflow/contrib/cluster_resolver/BUILD  19
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt  4
-rw-r--r--  tensorflow/contrib/cmake/tf_core_framework.cmake  94
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake  15
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake  4
-rw-r--r--  tensorflow/contrib/copy_graph/python/util/copy_elements.py  6
-rw-r--r--  tensorflow/contrib/data/__init__.py  2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD  2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py  145
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py  17
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py  484
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py  173
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py  4
-rw-r--r--  tensorflow/contrib/distribute/BUILD  1
-rw-r--r--  tensorflow/contrib/distribute/__init__.py  2
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD  1
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  67
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_strategy.py  2
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  56
-rw-r--r--  tensorflow/contrib/eager/python/examples/densenet/BUILD  29
-rw-r--r--  tensorflow/contrib/eager/python/examples/densenet/densenet.py  274
-rw-r--r--  tensorflow/contrib/eager/python/examples/densenet/densenet_test.py  83
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist.py  14
-rw-r--r--  tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb  1184
-rw-r--r--  tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb  689
-rw-r--r--  tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb  6
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb  6
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb  6
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb  6
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb  6
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py  24
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/BUILD  2
-rw-r--r--  tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb  4
-rw-r--r--  tensorflow/contrib/estimator/BUILD  29
-rw-r--r--  tensorflow/contrib/estimator/__init__.py  7
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/baseline_test.py  8
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/early_stopping.py  468
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/early_stopping_test.py  233
-rw-r--r--  tensorflow/contrib/gan/BUILD  15
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py  200
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py  227
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/head_impl.py  10
-rw-r--r--  tensorflow/contrib/gan/python/estimator/python/head_test.py  2
-rw-r--r--  tensorflow/contrib/kafka/ops/kafka_ops.cc  44
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py  26
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py  20
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py  14
-rw-r--r--  tensorflow/contrib/linear_optimizer/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/allocation.h  1
-rw-r--r--  tensorflow/contrib/lite/arena_planner.cc  21
-rw-r--r--  tensorflow/contrib/lite/arena_planner.h  9
-rw-r--r--  tensorflow/contrib/lite/arena_planner_test.cc  2
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl  6
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h  14
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h  2
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/BUILD  35
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util.cc  47
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util.h  35
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/util_test.cc  100
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc  16
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc  94
-rw-r--r--  tensorflow/contrib/lite/examples/android/app/README.md  19
-rw-r--r--  tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt  38
-rw-r--r--  tensorflow/contrib/lite/g3doc/benchmarks.md  178
-rw-r--r--  tensorflow/contrib/lite/g3doc/models.md  32
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md  25
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc  13
-rw-r--r--  tensorflow/contrib/lite/interpreter.h  4
-rw-r--r--  tensorflow/contrib/lite/interpreter_test.cc  16
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java  8
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java  9
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java  21
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java  239
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java  179
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc  307
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h  79
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc  123
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/tensor_jni.h  40
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java  15
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java  236
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java  131
-rw-r--r--  tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java  4
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD  31
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max.cc (renamed from tensorflow/contrib/lite/kernels/arg_max.cc)  69
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max_test.cc (renamed from tensorflow/contrib/lite/kernels/arg_max_test.cc)  89
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc  56
-rw-r--r--  tensorflow/contrib/lite/kernels/conv.cc  1
-rw-r--r--  tensorflow/contrib/lite/kernels/eigen_support.cc  54
-rw-r--r--  tensorflow/contrib/lite/kernels/eigen_support.h  7
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup.cc  3
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup_test.cc  36
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant.cc  92
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant_test.cc  112
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h  73
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h  63
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc  92
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h  226
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/quantization_util.h  10
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h  73
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h  173
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h  23
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc  3
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm_test.cc  3
-rw-r--r--  tensorflow/contrib/lite/kernels/mul.cc  67
-rw-r--r--  tensorflow/contrib/lite/kernels/mul_test.cc  58
-rw-r--r--  tensorflow/contrib/lite/kernels/pooling.cc  98
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc  4
-rw-r--r--  tensorflow/contrib/lite/kernels/select.cc  3
-rw-r--r--  tensorflow/contrib/lite/kernels/select_test.cc  13
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv.cc  111
-rw-r--r--  tensorflow/contrib/lite/kernels/transpose_conv_test.cc  121
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc  3
-rw-r--r--  tensorflow/contrib/lite/model.cc  27
-rw-r--r--  tensorflow/contrib/lite/model.h  1
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc  15
-rw-r--r--  tensorflow/contrib/lite/python/BUILD  6
-rw-r--r--  tensorflow/contrib/lite/python/interpreter.py  22
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_test.py  23
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc  265
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h  26
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i  43
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs  20
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h  315
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py  57
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc  8
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD  3
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc  22
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc  1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h  3
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc  25
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc  11
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  7
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc  17
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  10
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc  69
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h  20
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc  12
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc  6
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc  78
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc  112
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc  46
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc  80
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc  19
-rw-r--r--  tensorflow/contrib/lite/toco/model.h  47
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.cc  18
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc  56
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc  7
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc  12
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc  8
-rw-r--r--  tensorflow/contrib/lite/tools/BUILD  1
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark/README.md  2
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc  3
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark/benchmark_model.h  1
-rw-r--r--  tensorflow/contrib/lite/tools/visualize.py  17
-rw-r--r--  tensorflow/contrib/metrics/BUILD  1
-rw-r--r--  tensorflow/contrib/metrics/__init__.py  1
-rw-r--r--  tensorflow/contrib/metrics/python/metrics/classification.py  121
-rw-r--r--  tensorflow/contrib/metrics/python/metrics/classification_test.py  202
-rw-r--r--  tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py  2
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_ops.py  163
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.cc  80
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.cu.cc  117
-rw-r--r--  tensorflow/contrib/mpi_collectives/ring.h  327
-rw-r--r--  tensorflow/contrib/opt/python/training/addsign_test.py  6
-rw-r--r--  tensorflow/contrib/opt/python/training/ggt.py  2
-rw-r--r--  tensorflow/contrib/opt/python/training/powersign_test.py  2
-rw-r--r--  tensorflow/contrib/proto/BUILD  4
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/BUILD  86
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl  89
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py  68
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py  275
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py  310
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt  94
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py  155
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py  177
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt  161
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt  16
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt  20
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt  29
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py  407
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt  32
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt  62
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt  21
-rw-r--r--  tensorflow/contrib/proto/python/kernel_tests/test_example.proto  159
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py  13
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph.py  4
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_parameterized_test.py  76
-rw-r--r--  tensorflow/contrib/rnn/BUILD  1
-rw-r--r--  tensorflow/contrib/rnn/__init__.py  3
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py  161
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py  342
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/BUILD  3
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py  52
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py  8
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/test_example.proto  147
-rw-r--r--  tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py  42
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py  14
-rw-r--r--  tensorflow/contrib/tensorboard/db/BUILD  1
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc  63
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc  1354
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc  11
-rw-r--r--  tensorflow/contrib/tensorrt/ops/trt_engine_op.cc  10
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc  62
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD  1
-rw-r--r--  tensorflow/contrib/tpu/BUILD  33
-rw-r--r--  tensorflow/contrib/tpu/__init__.py  2
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py  152
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config.py  58
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_config_test.py  55
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_context.py  107
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  201
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_feed.py  16
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py  6
-rw-r--r--  tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py  187
-rw-r--r--  tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py  145
-rw-r--r--  tensorflow/core/BUILD  62
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt  4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt  4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt  62
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt (renamed from tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt)  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Acos.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Add.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_AsString.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Asin.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Atan.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Cos.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Cross.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Diag.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Equal.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Exp.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FFT.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Floor.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Greater.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Less.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Log.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt  4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Qr.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt  4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Rint.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Sin.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Substr.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Tan.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Tile.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt  2
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt  2
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  17
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc  8
-rw-r--r--  tensorflow/core/common_runtime/executor.cc  23
-rw-r--r--  tensorflow/core/common_runtime/executor.h  3
-rw-r--r--  tensorflow/core/common_runtime/placer.cc  55
-rw-r--r--  tensorflow/core/common_runtime/placer_test.cc  40
-rw-r--r--  tensorflow/core/common_runtime/session.cc  20
-rw-r--r--  tensorflow/core/common_runtime/session_factory.h  7
-rw-r--r--  tensorflow/core/common_runtime/session_test.cc  6
-rw-r--r--  tensorflow/core/debug/BUILD  55
-rw-r--r--  tensorflow/core/debug/debug_gateway.cc  122
-rw-r--r--  tensorflow/core/debug/debug_gateway.h  83
-rw-r--r--  tensorflow/core/debug/debug_gateway_test.cc  1011
-rw-r--r--  tensorflow/core/distributed_runtime/BUILD  4
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc  2
-rw-r--r--  tensorflow/core/distributed_runtime/master_test.cc  2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/BUILD  14
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_channel.cc  8
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc  5
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc  2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc  164
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h  224
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc  2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc  18
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session.cc  15
-rw-r--r--  tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc  26
-rw-r--r--  tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h  6
-rw-r--r--  tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc  47
-rw-r--r--  tensorflow/core/framework/api_def.proto  6
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc  12
-rw-r--r--  tensorflow/core/framework/common_shape_fns.h  17
-rw-r--r--  tensorflow/core/framework/op_kernel.cc  10
-rw-r--r--  tensorflow/core/framework/op_kernel.h  3
-rw-r--r--  tensorflow/core/graph/tensor_id.cc  3
-rw-r--r--  tensorflow/core/graph/tensor_id.h  7
-rw-r--r--  tensorflow/core/grappler/costs/BUILD  1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc  173
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc  248
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt  117
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt  251
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt  251
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt  317
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt  597
-rw-r--r--  tensorflow/core/grappler/op_types.cc  2
-rw-r--r--  tensorflow/core/grappler/op_types.h  1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  171
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h  1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc  42
-rw-r--r--  tensorflow/core/grappler/optimizers/data/BUILD  34
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_rename.cc  51
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_rename.h  46
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_rename_test.cc  42
-rw-r--r--  tensorflow/core/kernels/BUILD  19
-rw-r--r--  tensorflow/core/kernels/concat_op.cc  2
-rw-r--r--  tensorflow/core/kernels/data/BUILD  10
-rw-r--r--  tensorflow/core/kernels/data/captured_function.cc  12
-rw-r--r--  tensorflow/core/kernels/data/generator_dataset_op.cc  3
-rw-r--r--  tensorflow/core/kernels/data/identity_dataset_op.cc  102
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc  23
-rw-r--r--  tensorflow/core/kernels/data/optimize_dataset_op.cc  92
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc  7
-rw-r--r--  tensorflow/core/kernels/initializable_lookup_table.h  2
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op.cc  181
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op_test.cc  237
-rw-r--r--  tensorflow/core/kernels/partitioned_function_ops.cc  18
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc  29
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.h  9
-rw-r--r--  tensorflow/core/kernels/roll_op.cc  3
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h  12
-rw-r--r--  tensorflow/core/kernels/sendrecv_ops.cc  1
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt  1096
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc  19
-rw-r--r--  tensorflow/core/ops/debug_ops.cc  2
-rw-r--r--  tensorflow/core/ops/functional_ops.cc  6
-rw-r--r--  tensorflow/core/ops/image_ops.cc  32
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  147
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc  5
-rw-r--r--  tensorflow/core/platform/default/build_config/BUILD  12
-rw-r--r--  tensorflow/core/platform/env.h  2
-rw-r--r--  tensorflow/core/platform/numa.h  62
-rw-r--r--  tensorflow/core/platform/numa_test.cc  61
-rw-r--r--  tensorflow/core/platform/posix/port.cc  24
-rw-r--r--  tensorflow/core/platform/profile_utils/cpu_utils.cc  37
-rw-r--r--  tensorflow/core/platform/s3/s3_crypto.cc  113
-rw-r--r--  tensorflow/core/platform/s3/s3_crypto.h  35
-rw-r--r--  tensorflow/core/platform/vmodule_benchmark_test.cc (renamed from tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc)  28
-rw-r--r--  tensorflow/core/platform/vmodule_test.cc  117
-rw-r--r--  tensorflow/core/protobuf/debug.proto  4
-rw-r--r--  tensorflow/core/public/session.h  2
-rw-r--r--  tensorflow/core/util/sparse/dim_comparator.h  16
-rw-r--r--  tensorflow/core/util/sparse/group_iterator.h  6
-rw-r--r--  tensorflow/core/util/sparse/sparse_tensor.h  196
-rw-r--r--  tensorflow/core/util/sparse/sparse_tensor_test.cc  91
-rw-r--r--  tensorflow/core/util/tensor_format.cc  2
-rw-r--r--  tensorflow/docs_src/deploy/s3.md  2
-rw-r--r--  tensorflow/docs_src/extend/index.md  3
-rw-r--r--  tensorflow/docs_src/extend/new_data_formats.md  60
-rw-r--r--  tensorflow/docs_src/guide/eager.md  4
-rw-r--r--  tensorflow/docs_src/guide/feature_columns.md  6
-rw-r--r--  tensorflow/docs_src/guide/graph_viz.md  3
-rw-r--r--  tensorflow/docs_src/guide/index.md  15
-rw-r--r--  tensorflow/docs_src/guide/leftnav_files  4
-rw-r--r--  tensorflow/docs_src/install/install_linux.md  2
-rw-r--r--  tensorflow/docs_src/javascript/index.md  5
-rw-r--r--  tensorflow/docs_src/javascript/leftnav_files  1
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md  35
-rw-r--r--  tensorflow/docs_src/tutorials/_index.yaml  13
-rw-r--r--  tensorflow/docs_src/tutorials/_toc.yaml  24
-rw-r--r--  tensorflow/docs_src/tutorials/estimators/cnn.md (renamed from tensorflow/docs_src/tutorials/images/layers.md)  0
-rw-r--r--  tensorflow/docs_src/tutorials/images/deep_cnn.md  14
-rw-r--r--  tensorflow/docs_src/tutorials/images/image_recognition.md  2
-rw-r--r--  tensorflow/docs_src/tutorials/representation/linear.md  10
-rw-r--r--  tensorflow/docs_src/tutorials/representation/wide.md  461
-rw-r--r--  tensorflow/docs_src/tutorials/representation/wide_and_deep.md  243
-rw-r--r--  tensorflow/docs_src/tutorials/representation/word2vec.md  10
-rw-r--r--  tensorflow/examples/speech_commands/BUILD  1
-rw-r--r--  tensorflow/examples/speech_commands/freeze.py  64
-rw-r--r--  tensorflow/examples/speech_commands/freeze_test.py  54
-rw-r--r--  tensorflow/examples/speech_commands/generate_streaming_test_wav.py  10
-rw-r--r--  tensorflow/examples/speech_commands/input_data.py  135
-rw-r--r--  tensorflow/examples/speech_commands/input_data_test.py  87
-rw-r--r--  tensorflow/examples/speech_commands/models.py  302
-rw-r--r--  tensorflow/examples/speech_commands/models_test.py  40
-rw-r--r--  tensorflow/examples/speech_commands/train.py  58
-rw-r--r--  tensorflow/go/graph.go  14
-rw-r--r--  tensorflow/go/op/scope.go  31
-rw-r--r--  tensorflow/go/op/scope_test.go  15
-rw-r--r--  tensorflow/go/op/wrappers.go  642
-rw-r--r--  tensorflow/go/operation.go  6
-rw-r--r--  tensorflow/go/operation_test.go  23
-rw-r--r--  tensorflow/java/maven/hadoop/pom.xml  196
-rw-r--r--  tensorflow/java/maven/libtensorflow/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni/pom.xml  2
-rw-r--r--  tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml  2
-rw-r--r--  tensorflow/java/maven/pom.xml  2
-rw-r--r--  tensorflow/java/maven/proto/pom.xml  2
-rw-r--r--  tensorflow/java/maven/run_inside_container.sh  5
-rw-r--r--  tensorflow/java/maven/spark-connector/pom.xml  355
-rw-r--r--  tensorflow/java/maven/tensorflow/pom.xml  2
-rw-r--r--  tensorflow/java/src/gen/cc/java_defs.h  2
-rw-r--r--  tensorflow/java/src/gen/cc/op_generator.h  2
-rw-r--r--  tensorflow/java/src/gen/cc/op_specs.cc  148
-rw-r--r--  tensorflow/java/src/gen/cc/op_specs.h  40
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Input.java  48
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFString.java  27
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFType.java  20
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java  30
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/Types.java  52
-rw-r--r--  tensorflow/java/src/main/native/session_jni.cc  10
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java  2
-rw-r--r--  tensorflow/python/BUILD  40
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD  1
-rw-r--r--  tensorflow/python/data/kernel_tests/iterator_ops_test.py  64
-rw-r--r--  tensorflow/python/data/kernel_tests/map_dataset_op_test.py  7
-rw-r--r--  tensorflow/python/data/ops/BUILD  2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py15
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py63
-rw-r--r--tensorflow/python/debug/BUILD2
-rw-r--r--tensorflow/python/eager/backprop.py11
-rw-r--r--tensorflow/python/eager/backprop_test.py14
-rw-r--r--tensorflow/python/eager/function.py134
-rw-r--r--tensorflow/python/eager/function_test.py66
-rw-r--r--tensorflow/python/eager/graph_callable.py4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc81
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py43
-rw-r--r--tensorflow/python/estimator/BUILD8
-rw-r--r--tensorflow/python/estimator/api/BUILD5
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py4
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py4
-rw-r--r--tensorflow/python/estimator/canned/head.py22
-rw-r--r--tensorflow/python/estimator/canned/head_test.py31
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py14
-rw-r--r--tensorflow/python/estimator/estimator.py4
-rw-r--r--tensorflow/python/estimator/keras.py23
-rw-r--r--tensorflow/python/estimator/keras_test.py23
-rw-r--r--tensorflow/python/estimator/run_config.py11
-rw-r--r--tensorflow/python/estimator/training.py4
-rw-r--r--tensorflow/python/feature_column/BUILD68
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py3600
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py6583
-rw-r--r--tensorflow/python/framework/common_shapes.py12
-rw-r--r--tensorflow/python/framework/function_test.py16
-rw-r--r--tensorflow/python/framework/ops.py119
-rw-r--r--tensorflow/python/framework/tensor_util_test.py72
-rw-r--r--tensorflow/python/framework/traceable_stack.py135
-rw-r--r--tensorflow/python/framework/traceable_stack_test.py133
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/keras/applications/mobilenet.py22
-rw-r--r--tensorflow/python/keras/backend.py30
-rw-r--r--tensorflow/python/keras/callbacks.py34
-rw-r--r--tensorflow/python/keras/callbacks_test.py65
-rw-r--r--tensorflow/python/keras/engine/base_layer.py95
-rw-r--r--tensorflow/python/keras/engine/training.py2
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py24
-rw-r--r--tensorflow/python/keras/engine/training_eager.py14
-rw-r--r--tensorflow/python/keras/engine/training_generator.py20
-rw-r--r--tensorflow/python/keras/initializers.py51
-rw-r--r--tensorflow/python/keras/layers/core.py9
-rw-r--r--tensorflow/python/keras/layers/embeddings.py5
-rw-r--r--tensorflow/python/keras/layers/normalization.py30
-rw-r--r--tensorflow/python/keras/layers/normalization_test.py18
-rw-r--r--tensorflow/python/keras/layers/recurrent.py9
-rw-r--r--tensorflow/python/keras/layers/wrappers.py116
-rw-r--r--tensorflow/python/keras/layers/wrappers_test.py60
-rw-r--r--tensorflow/python/keras/models_test.py9
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py29
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py7
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py57
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py6
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py68
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py2
-rw-r--r--tensorflow/python/layers/base.py41
-rw-r--r--tensorflow/python/layers/base_test.py26
-rw-r--r--tensorflow/python/lib/core/numpy.h3
-rw-r--r--tensorflow/python/lib/core/py_util.cc3
-rw-r--r--tensorflow/python/ops/functional_ops.py2
-rw-r--r--tensorflow/python/ops/image_ops_impl.py65
-rw-r--r--tensorflow/python/ops/init_ops.py4
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_diag.py5
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_low_rank_update.py31
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_lower_triangular.py8
-rw-r--r--tensorflow/python/ops/linalg_ops.py2
-rw-r--r--tensorflow/python/ops/logging_ops.py9
-rw-r--r--tensorflow/python/ops/math_ops.py14
-rw-r--r--tensorflow/python/ops/metrics_impl.py20
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py2
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py120
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py56
-rw-r--r--tensorflow/python/ops/script_ops.py2
-rw-r--r--tensorflow/python/ops/state_ops.py5
-rw-r--r--tensorflow/python/ops/variable_scope.py72
-rw-r--r--tensorflow/python/ops/variables.py4
-rw-r--r--tensorflow/python/platform/benchmark.py10
-rw-r--r--tensorflow/python/platform/self_check.py2
-rw-r--r--tensorflow/python/tools/api/generator/BUILD (renamed from tensorflow/tools/api/generator/BUILD)17
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl (renamed from tensorflow/tools/api/generator/api_gen.bzl)61
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py (renamed from tensorflow/tools/api/generator/create_python_api.py)47
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api_test.py (renamed from tensorflow/tools/api/generator/create_python_api_test.py)11
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs.py (renamed from tensorflow/tools/api/generator/doc_srcs.py)0
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs_test.py (renamed from tensorflow/tools/api/generator/doc_srcs_test.py)4
-rw-r--r--tensorflow/python/training/distribute.py56
-rw-r--r--tensorflow/python/training/optimizer.py7
-rw-r--r--tensorflow/python/training/quantize_training.i2
-rw-r--r--tensorflow/python/util/deprecation.py38
-rw-r--r--tensorflow/python/util/deprecation_test.py6
-rw-r--r--tensorflow/python/util/py_checkpoint_reader.i1
-rw-r--r--tensorflow/python/util/stat_summarizer.i25
-rw-r--r--tensorflow/python/util/tf_export.py45
-rw-r--r--tensorflow/python/util/tf_export_test.py2
-rw-r--r--tensorflow/python/util/tf_inspect.py10
-rw-r--r--tensorflow/python/util/tf_inspect_test.py12
-rw-r--r--tensorflow/python/util/tf_stack.py97
-rw-r--r--tensorflow/security/advisory/tfsa-2018-001.md2
-rw-r--r--tensorflow/security/index.md2
-rw-r--r--tensorflow/stream_executor/event.cc11
-rw-r--r--tensorflow/stream_executor/event.h3
-rw-r--r--tensorflow/stream_executor/host/host_gpu_executor.cc11
-rw-r--r--tensorflow/stream_executor/stream.cc19
-rw-r--r--tensorflow/tensorflow.bzl43
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt6
-rw-r--r--tensorflow/tools/api/lib/python_object_to_proto_visitor.py3
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py42
-rwxr-xr-xtensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh28
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py10
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh4
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/common_env.sh6
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh6
-rw-r--r--tensorflow/tools/compatibility/ast_edits.py502
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-cpu-mkl83
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7115
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl2
-rw-r--r--tensorflow/tools/docker/Dockerfile.gpu1
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh2
-rw-r--r--tensorflow/workspace.bzl59
-rw-r--r--third_party/clang_toolchain/download_clang.bzl8
-rw-r--r--third_party/codegen.BUILD16
-rw-r--r--third_party/gpus/crosstool/BUILD.tpl20
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL.tpl869
-rwxr-xr-xthird_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl6
-rw-r--r--third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl20
-rw-r--r--third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl192
-rw-r--r--third_party/gpus/cuda/BUILD.windows.tpl163
-rw-r--r--third_party/gpus/cuda_configure.bzl2163
-rw-r--r--third_party/llvm/llvm.autogenerated.BUILD235
-rw-r--r--third_party/llvm/llvm.bzl13
-rw-r--r--third_party/mkl/LICENSE201
-rw-r--r--third_party/mkl_dnn/build_defs.bzl4
-rw-r--r--third_party/nanopb.BUILD23
-rw-r--r--third_party/nasm.BUILD180
-rw-r--r--tools/bazel.rc2
1199 files changed, 54463 insertions, 19262 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 4b03394427..7bb1e3e1c8 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -21,7 +21,7 @@
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
API supports broadcasting for Bijectors with new API changes.
-## Breaking Chances
+## Breaking Changes
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
`variable_scope(tf.get_variable_scope(), ...)`.
* Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external.
diff --git a/configure.py b/configure.py
index 31a83b4a15..d411214817 100644
--- a/configure.py
+++ b/configure.py
@@ -35,7 +35,7 @@ except ImportError:
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_CUDNN_VERSION = '7'
-_DEFAULT_NCCL_VERSION = '1.3'
+_DEFAULT_NCCL_VERSION = '2.2'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
@@ -835,6 +835,8 @@ def set_tf_cuda_version(environ_cp):
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
cuda_toolkit_path = get_from_env_or_user_or_default(
environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
+ if is_windows() or is_cygwin():
+ cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_path = 'lib/x64/cudart.lib'
@@ -1095,8 +1097,10 @@ def set_tf_nccl_install_path(environ_cp):
raise ValueError('Currently NCCL is only supported on Linux platforms.')
ask_nccl_version = (
- 'Please specify the NCCL version you want to use. '
- '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION
+ 'Please specify the NCCL version you want to use. If NCCL %s is not '
+ 'installed, then version 1.3 can be fetched automatically, but it may '
+ 'give worse performance with multiple GPUs. '
+ '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_nccl_version = get_from_env_or_user_or_default(
@@ -1234,28 +1238,13 @@ def set_tf_cuda_compute_capabilities(environ_cp):
def set_other_cuda_vars(environ_cp):
"""Set other CUDA related variables."""
- if is_windows():
- # The following three variables are needed for MSVC toolchain configuration
- # in Bazel
- environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH')
- environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get(
- 'TF_CUDA_COMPUTE_CAPABILITIES')
- environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1
- write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH'))
- write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE',
- environ_cp.get('CUDA_COMPUTE_CAPABILITIE'))
- write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION',
- environ_cp.get('NO_WHOLE_ARCHIVE_OPTION'))
- write_to_bazelrc('build --config=win-cuda')
- write_to_bazelrc('test --config=win-cuda')
+ # If CUDA is enabled, always use GPU during build and test.
+ if environ_cp.get('TF_CUDA_CLANG') == '1':
+ write_to_bazelrc('build --config=cuda_clang')
+ write_to_bazelrc('test --config=cuda_clang')
else:
- # If CUDA is enabled, always use GPU during build and test.
- if environ_cp.get('TF_CUDA_CLANG') == '1':
- write_to_bazelrc('build --config=cuda_clang')
- write_to_bazelrc('test --config=cuda_clang')
- else:
- write_to_bazelrc('build --config=cuda')
- write_to_bazelrc('test --config=cuda')
+ write_to_bazelrc('build --config=cuda')
+ write_to_bazelrc('test --config=cuda')
def set_host_cxx_compiler(environ_cp):
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 51eea94847..518c2b0489 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -20,7 +20,7 @@ load(
"tf_additional_binary_deps",
)
load(
- "//tensorflow/tools/api/generator:api_gen.bzl",
+ "//tensorflow/python/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 734e712daa..1adb0458c3 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
} else {
any_gradient_nonzero = true;
- auto new_gradients = vspace.AggregateGradients(grad_it->second);
+ Gradient* new_gradients = nullptr;
+ if (grad_it->second.size() == 1) {
+ new_gradients = grad_it->second.at(0);
+ } else {
+ new_gradients = vspace.AggregateGradients(grad_it->second);
+ }
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
} else {
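
The tape.h change above replaces an unconditional AggregateGradients call with a fast path for the common case of a single pending gradient. A minimal standalone sketch of the same pattern follows; Gradient and VSpace are hypothetical stand-ins for the real tape types, not part of this change:

    #include <vector>

    struct Gradient;                     // stand-in for the tape's gradient type
    struct VSpace {                      // stand-in for the real vspace interface
      Gradient* AggregateGradients(const std::vector<Gradient*>& grads);
    };

    // Mirrors the fast path introduced above: a lone pending gradient is
    // returned as-is, and AggregateGradients only runs when there is
    // actually something to reduce.
    Gradient* CombineGradients(VSpace* vspace,
                               const std::vector<Gradient*>& grads) {
      if (grads.size() == 1) return grads.at(0);
      return vspace->AggregateGradients(grads);
    }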
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index e18fdf6c57..8486b585c8 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
tensorflow::shape_inference::ShapeHandle shape;
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
- if (status->status.ok()) return;
+ if (!status->status.ok()) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
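
The one-character python_api.cc fix above corrects an inverted error check: the old code returned early on success, so the shape/type pair was appended only when proto parsing had failed. A sketch of the corrected early-return-on-error flow, with Status and ShapeAndType as stand-ins rather than the real TensorFlow types:

    #include <vector>

    struct Status { bool ok() const { return true; } };  // stand-in for tensorflow::Status
    struct ShapeAndType {};                              // stand-in for the inference types

    // Bail out when parsing failed; record the entry only on success.
    void MaybeAppend(const Status& parse_status, const ShapeAndType& entry,
                     std::vector<ShapeAndType>* out) {
      if (!parse_status.ok()) return;
      out->push_back(entry);
    }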
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 06a3be18e0..730b1b669b 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -34,6 +34,35 @@ cc_library(
)
cc_library(
+ name = "reader",
+ srcs = ["reader.cc"],
+ hdrs = ["reader.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "reader_test",
+ srcs = ["reader_test.cc"],
+ data = [
+ ":saved_model_half_plus_two",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":constants",
+ ":reader",
+ ":tag_constants",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
name = "loader",
hdrs = ["loader.h"],
deps = [
@@ -54,6 +83,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
+ ":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index faa1e378d0..07807ed2f3 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
@@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New(
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
-Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
- const string saved_model_pb_path =
- io::JoinPath(export_dir, kSavedModelFilenamePb);
- if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
- return ReadBinaryProto(Env::Default(), saved_model_pb_path,
- saved_model_proto);
- }
- const string saved_model_pbtxt_path =
- io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
- if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
- return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
- saved_model_proto);
- }
- return Status(error::Code::NOT_FOUND,
- "Could not find SavedModel .pb or .pbtxt at supplied export "
- "directory path: " +
- export_dir);
-}
-
-string GetTagsAsString(const std::unordered_set<string>& tags) {
- string tags_as_string = "{ ";
- for (const string& tag : tags) {
- tags_as_string = strings::StrCat(tags_as_string, tag, " ");
- }
- tags_as_string = strings::StrCat(tags_as_string, "}");
- return tags_as_string;
-}
-
-Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
- const std::unordered_set<string>& tags,
- MetaGraphDef* meta_graph_def_to_load) {
- for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
- // Get tags from the meta_graph_def.
- std::unordered_set<string> graph_tags;
- for (const string& tag : meta_graph_def.meta_info_def().tags()) {
- graph_tags.insert(tag);
- }
- // Match with the set of tags provided.
- if (graph_tags == tags) {
- *meta_graph_def_to_load = meta_graph_def;
- return Status::OK();
- }
- }
- return Status(error::Code::NOT_FOUND,
- "Could not find meta graph def matching supplied tags: " +
- GetTagsAsString(tags) +
- ". To inspect available tag-sets in the SavedModel, please "
- "use the SavedModel CLI: `saved_model_cli`");
-}
-
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
@@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle) {
- if (!MaybeSavedModelDirectory(export_dir)) {
- return Status(error::Code::NOT_FOUND,
- "SavedModel not found in export directory: " + export_dir);
- }
- LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
- << "; from: " << export_dir;
-
- SavedModel saved_model_proto;
- TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
-
- TF_RETURN_IF_ERROR(
- FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
+ TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
+ &bundle->meta_graph_def));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
@@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
return end_microseconds - start_microseconds;
}();
auto log_and_count = [&](const string& status_str) {
- LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
- << "; Status: " << status_str << ". Took "
+ LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
+ << " }; Status: " << status_str << ". Took "
<< load_latency_microsecs << " microseconds.";
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
};
diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc
new file mode 100644
index 0000000000..2146c8a197
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader.cc
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/reader.h"
+
+#include <unordered_set>
+
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/protobuf/saved_model.pb.h"
+
+namespace tensorflow {
+namespace {
+
+Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
+ LOG(INFO) << "Reading SavedModel from: " << export_dir;
+
+ const string saved_model_pb_path =
+ io::JoinPath(export_dir, kSavedModelFilenamePb);
+ if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
+ return ReadBinaryProto(Env::Default(), saved_model_pb_path,
+ saved_model_proto);
+ }
+ const string saved_model_pbtxt_path =
+ io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
+ if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
+ return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
+ saved_model_proto);
+ }
+ return Status(error::Code::NOT_FOUND,
+ "Could not find SavedModel .pb or .pbtxt at supplied export "
+ "directory path: " +
+ export_dir);
+}
+
+Status FindMetaGraphDef(const SavedModel& saved_model_proto,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* meta_graph_def) {
+ LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
+ << " }";
+ for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
+ // Get tags from the graph_def.
+ std::unordered_set<string> graph_tags;
+ for (const string& tag : graph_def.meta_info_def().tags()) {
+ graph_tags.insert(tag);
+ }
+ // Match with the set of tags provided.
+ if (graph_tags == tags) {
+ *meta_graph_def = graph_def;
+ return Status::OK();
+ }
+ }
+ return Status(
+ error::Code::NOT_FOUND,
+ strings::StrCat(
+ "Could not find meta graph def matching supplied tags: { ",
+ str_util::Join(tags, " "),
+ " }. To inspect available tag-sets in the SavedModel, please "
+ "use the SavedModel CLI: `saved_model_cli`"));
+}
+
+} // namespace
+
+Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* const meta_graph_def) {
+ SavedModel saved_model_proto;
+ TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
+ TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
+ return Status::OK();
+}
+
+} // namespace tensorflow
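
One subtlety of FindMetaGraphDef above, which the NoTagMatchMultiple test in reader_test.cc below exercises: tags are matched by exact set equality, not subset containment. A minimal sketch of that semantics, using only the standard library:

    #include <string>
    #include <unordered_set>

    using Tags = std::unordered_set<std::string>;

    // A MetaGraphDef tagged {serve, gpu} does not satisfy a request
    // for {serve}; both sets must hold exactly the same tags.
    bool TagsMatch(const Tags& graph_tags, const Tags& requested) {
      return graph_tags == requested;
    }

    // TagsMatch({"serve"}, {"serve"})         -> true
    // TagsMatch({"serve", "gpu"}, {"serve"})  -> false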
diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h
new file mode 100644
index 0000000000..5815108df2
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/// Functions to read the SavedModel proto, or parts of it.
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
+#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
+
+#include <string>
+#include <unordered_set>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+
+// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
+// finds the MetaGraphDef that matches the given set of tags and writes it to
+// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
+// file does not exist or no MetaGraphDef matches the tags.
+Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
+ const std::unordered_set<string>& tags,
+ MetaGraphDef* const meta_graph_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_
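
Since reader.h is now the public entry point for reading a MetaGraphDef without creating a session, a minimal caller looks like the sketch below. The export directory is a hypothetical placeholder; only ReadMetaGraphDefFromSavedModel and kSavedModelTagServe come from the files in this change:

    #include "tensorflow/cc/saved_model/reader.h"
    #include "tensorflow/cc/saved_model/tag_constants.h"
    #include "tensorflow/core/protobuf/meta_graph.pb.h"

    // "/tmp/my_model/1" is a placeholder export directory.
    tensorflow::Status LoadServeMetaGraph(
        tensorflow::MetaGraphDef* meta_graph_def) {
      return tensorflow::ReadMetaGraphDefFromSavedModel(
          "/tmp/my_model/1", {tensorflow::kSavedModelTagServe},
          meta_graph_def);
    }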
diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc
new file mode 100644
index 0000000000..620e9c2eec
--- /dev/null
+++ b/tensorflow/cc/saved_model/reader_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/saved_model/reader.h"
+
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/tag_constants.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr char kTestDataPbTxt[] =
+ "cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
+constexpr char kTestDataSharded[] =
+ "cc/saved_model/testdata/half_plus_two/00000123";
+
+class ReaderTest : public ::testing::Test {
+ protected:
+ ReaderTest() {}
+
+ void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
+ const auto& tags = meta_graph_def.meta_info_def().tags();
+ EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
+ tags.end());
+ EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
+ EXPECT_EQ(
+ meta_graph_def.signature_def().at("serving_default").method_name(),
+ "tensorflow/serving/predict");
+ }
+};
+
+TEST_F(ReaderTest, TagMatch) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def));
+ CheckMetaGraphDef(meta_graph_def);
+}
+
+TEST_F(ReaderTest, NoTagMatch) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
+ &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: { missing-tag }"))
+ << st.error_message();
+}
+
+TEST_F(ReaderTest, NoTagMatchMultiple) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ Status st = ReadMetaGraphDefFromSavedModel(
+ export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(str_util::StrContains(
+ st.error_message(),
+ "Could not find meta graph def matching supplied tags: "))
+ << st.error_message();
+}
+
+TEST_F(ReaderTest, PbtxtFormat) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
+ TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def));
+ CheckMetaGraphDef(meta_graph_def);
+}
+
+TEST_F(ReaderTest, InvalidExportPath) {
+ MetaGraphDef meta_graph_def;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
+ Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
+ &meta_graph_def);
+ EXPECT_FALSE(st.ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index b3a1c19c9e..9c424b201e 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -60,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr =
namespace {
-bool AreAllParentsConst(const Node& n,
- const gtl::FlatSet<const Node*>& runtime_const_nodes) {
- if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
+bool AreAllParentsGuaranteedConst(
+ const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
return true;
@@ -93,7 +93,8 @@ void MarkGuaranteedConstants(
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
- if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
+ if (AreAllParentsGuaranteedConst(*n,
+ guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 4eb389e0c6..c0543a0079 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -742,10 +742,13 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
- auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f);
+ auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
+ auto const_guarantee_x2 =
+ ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
auto const_guarantee_x1 =
ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
- auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2);
+ auto add1 =
+ ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2);
add1.node()->AddAttr("_encapsulate", "encapsulate1");
Graph graph_before(OpRegistry::Global());
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 251a07304e..338fb5a6f0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
+ bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
// Get the platform_id_ for XLA_* devices.
if (platform_id_ == nullptr) {
@@ -180,8 +181,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(client, xla_allocator,
- allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(
+ client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 54a41a4daa..7ed609c437 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -40,23 +40,7 @@ namespace tensorflow {
XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
DeviceType device_type)
: client_(client), device_type_(std::move(device_type)) {}
-XlaCompilationCache::~XlaCompilationCache() {
- // Ensure any use of our programs have completed by waiting for all stream
- // executors to complete.
- for (auto* executor : client_->backend().stream_executors()) {
- bool ok = executor->SynchronizeAllActivity();
- if (!ok) {
- LOG(ERROR) << "Error synchronizing activity while waiting for all "
- "programs to complete";
- }
- }
- // TODO(b/110813685): Think about the program ownership model. Programs are
- // currently owned by the compilation cache which means we must wait for
- // program completion in the destructor. There are multiple compilation caches
- // around, which complicates things a little. Perhaps having programs be
- // shared_ptrs (an invasive change) would make the model easier to reason
- // about?
-}
+XlaCompilationCache::~XlaCompilationCache() = default;
string XlaCompilationCache::DebugString() {
return "XLA JIT compilation cache";
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index baccea2d6a..d288d37bc7 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(),
+ /*allocate_xla_tensors=*/true,
+ /*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables);
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 43648402f6..7e159e3171 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index ed007d603e..c55eba2f79 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
+ bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
- device->reset(new XlaDevice(
- options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
- padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
+ device->reset(
+ new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+ platform.ValueOrDie(), transfer_as_literal,
+ use_multiple_streams, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- PaddedShapeFn padded_shape_fn)
+ PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)),
- padded_shape_fn_(std::move(padded_shape_fn)) {}
+ padded_shape_fn_(std::move(padded_shape_fn)),
+ use_multiple_streams_(use_multiple_streams) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- se::Platform* platform, bool transfer_as_literal,
+ se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn, padded_shape_fn),
+ shape_representation_fn, padded_shape_fn,
+ use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
+ use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!device_to_host_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!host_to_device_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return host_to_device_stream_.get();
+}
+
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
- gpu_device_info_->default_context = new XlaDeviceContext(
- stream, client(), transfer_as_literal_, shape_representation_fn_);
+ gpu_device_info_->default_context =
+ new XlaDeviceContext(stream, stream, stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
- auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ auto ctx = new XlaDeviceContext(
+ stream, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- XlaTransferManager manager(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+ XlaTransferManager manager(stream, host_to_device_stream,
+ device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 02e88ee679..fccdb14368 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- PaddedShapeFn padded_shape_fn);
+ PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
// The index of the device on this host.
int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
}
const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
+ bool UseMultipleStreams() const { return use_multiple_streams_; }
+
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
PaddedShapeFn padded_shape_fn_;
+ const bool use_multiple_streams_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
+ // If 'use_multiple_streams' is true, we create separate streams for
+ // host-to-device and device-to-host communication.
// If padded_shape_fn is empty, a default implementation that returns
// the on-host shape is used.
static Status Create(
@@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
+ bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
+ bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
+ xla::StatusOr<se::Stream*> GetHostToDeviceStream();
+ xla::StatusOr<se::Stream*> GetDeviceToHostStream();
// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::Backend::StreamPtr stream_;
+ // If false, only stream_ is valid and all computation and transfers use
+ // stream_. If true, computation is performed by stream_ and transfers are
+ // performed by host_to_device_stream_/device_to_host_stream_.
+ bool use_multiple_streams_;
+ // If use_multiple_streams_ is true, host-to-device transfers are performed
+ // using this stream.
+ xla::Backend::StreamPtr host_to_device_stream_;
+ // If use_multiple_streams_ is true, device-to-host transfers are performed
+ // using this stream.
+ xla::Backend::StreamPtr device_to_host_stream_;
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
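A condensed sketch of the ordering contract these stream members rely on, assembled from the xla_device_context.cc and xla_launch_util.cc hunks below; the helper and its names are illustrative, not part of the change, but every call shown appears elsewhere in this diff:

    // Hypothetical helper expressing both halves of the contract.
    Status SyncExample(se::Stream* host_to_device_stream,
                       se::Stream* consumer_stream, XlaTensor* xla_tensor) {
      // Producer: after enqueueing a transfer on the host-to-device stream,
      // publish an event marking when the buffer becomes defined.
      se::Event event(host_to_device_stream->parent());
      TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
      host_to_device_stream->ThenRecordEvent(&event);
      xla_tensor->SetDefinedOn(host_to_device_stream, std::move(event));

      // Consumer: before the first read on another stream, wait once, then
      // mark the stream so later reads skip the wait.
      if (se::Event* e = xla_tensor->GetDefinitionEvent(consumer_stream)) {
        consumer_stream->ThenWaitFor(e);
        xla_tensor->SetDefinedOn(consumer_stream);
      }
      return Status::OK();
    }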
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 3bbf97afad..04778c0090 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -48,13 +48,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(stream),
+ : stream_(compute_stream),
+ host_to_device_stream_(host_to_device_stream),
+ device_to_host_stream_(device_to_host_stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
+ CHECK(host_to_device_stream_ != nullptr);
+ CHECK(device_to_host_stream_ != nullptr);
+ CHECK(stream_ != nullptr);
if (!shape_representation_fn_) {
shape_representation_fn_ =
[](const TensorShape& shape,
@@ -67,120 +74,110 @@ Status XlaTransferManager::TransferLiteralToDevice(
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
- // Create a reference to hold onto host_tensor until after the literal has
- // been transferred. Also make sure the literal exists until the function
- // asynchronously completes, as it will be wrapped in an xla::LiteralSlice.
- TensorReference ref(host_tensor);
- auto literal = std::make_shared<xla::BorrowingLiteral>(
+ xla::BorrowingLiteral literal(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
- const xla::ShapedBuffer& shaped_buffer =
- XlaTensor::FromTensor(device_tensor)->shaped_buffer();
- VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+ const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+ VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< shaped_buffer.ToString();
- TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- stream_, *literal, shaped_buffer));
- // Unref the host tensor, and capture the literal shared_ptr too so it goes
- // out of scope when the lambda completes.
- stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
+ TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDevice(
+ host_to_device_stream_, literal, shaped_buffer));
+ if (UseMultipleStreams()) {
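+ // With separate streams, the compute stream could otherwise race ahead of
+ // this copy; publish an event so consumers (via
+ // XlaTensor::GetDefinitionEvent) can order their reads after the transfer.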
+ se::Event event(stream_->parent());
+ TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ }
return Status::OK();
}
-void XlaTransferManager::TransferLiteralFromDevice(
- Tensor* host_tensor, const Tensor& device_tensor,
- const StatusCallback& done) const {
+Status XlaTransferManager::TransferLiteralFromDevice(
+ Tensor* host_tensor, const Tensor& device_tensor) const {
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
- TensorReference ref(device_tensor);
- transfer_manager_->TransferLiteralFromDevice(
- stream_, shaped_buffer,
- [=, &shaped_buffer](
- xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
- ref.Unref();
- done([&]() -> Status {
- TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
- VLOG(1) << "Transfer from device as literal: " << literal->ToString()
- << " " << shaped_buffer.ToString();
- Tensor tensor;
- TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
- // Reshape the tensor back to its declared shape.
- Status status;
- if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
- status = errors::Internal(
- "Tensor::CopyFrom failed when copying from XLA device to CPU");
- }
- return status;
- }());
- });
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+ transfer_manager_->TransferLiteralFromDevice(
+ device_to_host_stream_, shaped_buffer));
+ VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
+ << shaped_buffer.ToString();
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(
+ LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
+ // Reshape the tensor back to its declared shape.
+ if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
+ return errors::Internal(
+ "Tensor::CopyFrom failed when copying from XLA device to CPU");
+ }
+ return Status::OK();
}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
- if (cpu_tensor->NumElements() > 0) {
- VLOG(2) << "CopyCPUTensorToDevice "
- << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << " "
- << reinterpret_cast<const void*>(
- device_tensor->tensor_data().data())
- << " " << cpu_tensor->NumElements() << " "
- << cpu_tensor->shape().DebugString() << " "
- << device_tensor->shape().DebugString();
-
- void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
- const int64 total_bytes = cpu_tensor->TotalBytes();
-
- XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
- CHECK(xla_tensor);
-
- Status status;
- xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
- device_tensor->shape(), device_tensor->dtype());
- if (!shape_or_status.ok()) {
- done(shape_or_status.status());
+ if (cpu_tensor->NumElements() == 0) {
+ VLOG(2) << "CopyCPUTensorToDevice empty tensor";
+ done(Status::OK());
+ return;
+ }
+
+ VLOG(2) << "CopyCPUTensorToDevice "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+ << " " << cpu_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
+
+ void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+ CHECK(xla_tensor);
+
+ xla::StatusOr<TensorShape> shape_or_status =
+ shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
+ if (!shape_or_status.ok()) {
+ done(shape_or_status.status());
+ return;
+ }
+ TensorShape shape = shape_or_status.ValueOrDie();
+ if (!xla_tensor->has_shaped_buffer()) {
+ Status s =
+ xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
+ stream_->parent()->device_ordinal());
+ if (!s.ok()) {
+ done(s);
return;
}
- TensorShape shape = shape_or_status.ValueOrDie();
- if (!xla_tensor->has_shaped_buffer()) {
- status = xla_tensor->AllocateShapedBuffer(
- device_tensor->dtype(), shape, client_,
- stream_->parent()->device_ordinal());
- if (!status.ok()) {
- return done(status);
- }
- }
+ }
- if (transfer_as_literal_) {
- Tensor reshaped_cpu_tensor;
- if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
- done(errors::Internal(
- "Tensor::CopyFrom failed when copying from CPU to XLA device"));
- return;
- }
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- } else {
- se::DeviceMemoryBase dev_dst_ptr =
- XlaTensor::DeviceMemoryFromTensor(*device_tensor);
- stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
- }
+ Status status;
+ if (transfer_as_literal_) {
+ Tensor reshaped_cpu_tensor;
+ if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
+ done(errors::Internal(
+ "Tensor::CopyFrom failed when copying from CPU to XLA device"));
+ return;
+ }
+ status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+ } else {
+ se::DeviceMemoryBase dev_dst_ptr =
+ XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+ host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = host_to_device_stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s",
+ host_to_device_stream_, block_status.error_message().c_str());
}
- xla_tensor->set_host_tensor(*cpu_tensor);
-
- done(status);
- return;
}
+ xla_tensor->set_host_tensor(*cpu_tensor);
- VLOG(2) << "CopyCPUTensorToDevice empty tensor";
- done(Status::OK());
+ done(status);
}
void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@@ -188,51 +185,64 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
- if (device_tensor->NumElements() > 0) {
- VLOG(2) << "CopyDeviceTensorToCPU "
- << reinterpret_cast<const void*>(
- device_tensor->tensor_data().data())
- << " "
- << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << " " << device_tensor->NumElements() << " "
- << cpu_tensor->shape().DebugString() << " "
- << device_tensor->shape().DebugString();
-
- const int64 total_bytes = cpu_tensor->TotalBytes();
- se::DeviceMemoryBase dev_src_ptr =
- XlaTensor::DeviceMemoryFromTensor(*device_tensor);
- void* dst_ptr = DMAHelper::base(cpu_tensor);
+ if (device_tensor->NumElements() == 0) {
+ VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
+ done(Status::OK());
+ return;
+ }
+ VLOG(2) << "CopyDeviceTensorToCPU "
+ << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << " " << device_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
+
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+ se::DeviceMemoryBase dev_src_ptr =
+ XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+ void* dst_ptr = DMAHelper::base(cpu_tensor);
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+
+ if (se::Event* event =
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ device_to_host_stream_->ThenWaitFor(event);
+ xla_tensor->SetDefinedOn(device_to_host_stream_);
+ }
- Status status;
- if (transfer_as_literal_) {
- TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
- return;
- } else {
- stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
- }
- done(status);
+ Status status;
+ if (transfer_as_literal_) {
+ status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
+ } else {
+ device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = device_to_host_stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s", stream_,
+ block_status.error_message().c_str());
}
- return;
}
- VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
- done(Status::OK());
+ done(status);
}
void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
Tensor* dst_tensor,
const StatusCallback& done) {
- // Perform memory allocation now, and enqueue the device-to-device transfer.
- Status status = [&]() -> Status {
+ VLOG(2) << "CopyDeviceTensorToDevice "
+ << reinterpret_cast<const void*>(src_tensor.tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
+ // TODO(phawkins): replace this code with an asynchronous implementation.
+ auto body = [&]() {
if (src_tensor.NumElements() == 0) {
return Status::OK();
}
+ // TODO(jmolloy): We co-opt the device_to_host stream for device-to-device
+ // transfers; perhaps we should have a dedicated device-to-device stream,
+ // or one per device?
+ auto device_to_device_stream = device_to_host_stream_;
XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
CHECK(xla_src && xla_dst)
@@ -245,26 +255,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
}
- auto from_iter = xla_src->shaped_buffer().buffers().begin();
- auto to_iter = xla_dst->shaped_buffer().buffers().begin();
- for (auto end_iter = xla_src->shaped_buffer().buffers().end();
- from_iter != end_iter; ++from_iter, ++to_iter) {
- stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second,
- to_iter->second.size());
+
+ if (se::Event* event =
+ xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ device_to_device_stream->ThenWaitFor(event);
+ xla_src->SetDefinedOn(device_to_device_stream);
+ TF_RETURN_IF_ERROR(device_to_device_stream->BlockHostUntilDone());
}
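+ // The per-buffer copy below is a synchronous, host-side memcpy rather than
+ // work enqueued on a stream, so the host itself must block (above) until
+ // the source buffers are defined.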
+ TF_RETURN_IF_ERROR(
+ xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
+ [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+ const se::DeviceMemoryBase& from_buffer =
+ xla_src->shaped_buffer().buffers().element(index);
+ CHECK_EQ(buffer->size(), from_buffer.size());
+ if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
+ buffer->size())) {
+ return errors::Internal("Device to device memcpy failed");
+ }
+ return Status::OK();
+ }));
return Status::OK();
- }();
- if (!status.ok()) {
- return done(status);
- } else {
- stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
- }
+ };
+ done(body());
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(stream, client, transfer_as_literal,
+ : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+ client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index c5c81d65fe..c726495f96 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -64,13 +66,19 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const;
- void TransferLiteralFromDevice(Tensor* host_tensor,
- const Tensor& device_tensor,
- const StatusCallback& done) const;
+ Status TransferLiteralFromDevice(Tensor* host_tensor,
+ const Tensor& device_tensor) const;
+ bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
- // Stream obtained from a Device, used to transfer tensors between
- // CPU and device.
+ // The main compute stream of the device, used to synchronize the transfer
+ // streams if they are set.
se::Stream* stream_;
+ // The stream to use for transferring data from host to device. Can be
+ // identical to stream_, but must not be nullptr.
+ se::Stream* host_to_device_stream_;
+ // The stream to use for transferring data from device to host. Can be
+ // identical to stream_, but must not be nullptr.
+ se::Stream* device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -86,7 +94,9 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index a605335a94..134dcc1bb5 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -90,6 +90,9 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
+ DestroyResourceOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
.Device(DEVICE) \
.HostMemory("output") \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index c0d86a28c7..851b118b0c 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device);
if (!status.ok()) {
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 661187f4a8..4574559674 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
DEVICE_INTERPRETER_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 5ceccc769f..616c3ed2a2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors)
+ bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
- allocate_xla_tensors_(allocate_xla_tensors) {}
+ allocate_xla_tensors_(allocate_xla_tensors),
+ use_multiple_streams_(use_multiple_streams) {
+ if (use_multiple_streams_) {
+ CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+ "be allocating XLA tensors!";
+ }
+}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
@@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
t = &(ctx->input(arg_num));
}
+ if (use_multiple_streams_) {
+ CHECK(stream) << "Must have a stream available when using XLA tensors!";
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+ CHECK(xla_tensor);
+ if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+ stream->ThenWaitFor(event);
+ xla_tensor->SetDefinedOn(stream);
+ }
+ }
+
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -248,6 +266,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
@@ -302,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4390701ccb..90531174ff 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
+ // If 'use_multiple_streams' is true, tensors may be defined and used on
+ // multiple streams and so se::Events must be defined and waited for. If
+ // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
+ // because we track inter-stream dependencies through events inside XlaTensor
+ // objects.
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors);
+ bool allocate_xla_tensors,
+ bool use_multiple_streams);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
@@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
+ bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
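The constructor comment above pins down an invariant; a hedged sketch of both sides of it (hypothetical call sites, with client and xla_allocator standing in for real values; the failing case trips the CHECK added to the constructor in xla_launch_util.cc above):

    void LaunchContextExample(xla::LocalClient* client,
                              xla::DeviceMemoryAllocator* xla_allocator) {
      // OK: definition events can live inside the allocated XlaTensors.
      XlaComputationLaunchContext launch_context(
          client, xla_allocator,
          /*allocate_xla_tensors=*/true,
          /*use_multiple_streams=*/true);

      // CHECK-fails: with plain device pointers there is no XlaTensor in
      // which to record inter-stream definition events.
      XlaComputationLaunchContext bad_context(
          client, xla_allocator,
          /*allocate_xla_tensors=*/false,
          /*use_multiple_streams=*/true);
    }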
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 3c44c4ae6d..5dff187fff 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -73,6 +73,36 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
return Status::OK();
}
+se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
+ mutex_lock lock(mu_);
+ if (!definition_event_.has_value()) {
+ return nullptr;
+ }
+
+ // The set of defined streams is expected to be very small indeed (usually
+ // 1-2), so a simple linear scan should be fast enough.
+ if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
+ stream) != streams_defined_on_.end()) {
+ // stream is in streams_defined_on_; it doesn't need to be waited on.
+ return nullptr;
+ }
+
+ return &*definition_event_;
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+ mutex_lock lock(mu_);
+ CHECK(!definition_event_.has_value())
+ << "SetDefinedOn must only be called once!";
+ definition_event_ = std::move(event);
+ streams_defined_on_.push_back(stream);
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream) {
+ mutex_lock lock(mu_);
+ streams_defined_on_.push_back(stream);
+}
+
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index c54001a999..f7e401c731 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -85,6 +85,24 @@ class XlaTensor {
host_tensor_.reset(new Tensor(tensor));
}
+ // If the tensor's content is not yet defined on 'stream', and there exists an
+ // se::Event declaring when the tensor's content is defined, return it.
+ // Otherwise, return nullptr. If this function returns nullptr then the
+ // tensor's content can be read on 'stream' without additional
+ // synchronization.
+ se::Event* GetDefinitionEvent(se::Stream* stream);
+
+ // Assert that the tensor's content is defined on 'stream' by the time 'event'
+ // triggers.
+ void SetDefinedOn(se::Stream* stream, se::Event event);
+
+ // Assert that the tensor's content is defined on 'stream'. This version does
+ // not provide an event, and must be called *after* SetDefinedOn(Stream,
+ // Event). This call can be read as an assertion that the definition event has
+ // been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
+ // do not need to also wait on the event.
+ void SetDefinedOn(se::Stream* stream);
+
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,14 @@ class XlaTensor {
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
+ // An optional event that is triggered when the tensor's content has been
+ // defined. If this event is unset, the tensor's content is assumed to be
+ // always defined.
+ gtl::optional<se::Event> definition_event_;
+ // All streams on which the tensor's content is known to be defined, so
+ // that commands newly enqueued on them need not wait for the event.
+ gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+ mutex mu_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 273641f197..080bed50e6 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -98,6 +98,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adagrad_da_test",
+ size = "small",
+ srcs = ["adagrad_da_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "adam_test",
size = "small",
srcs = ["adam_test.py"],
@@ -112,6 +125,48 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adamax_test",
+ size = "small",
+ srcs = ["adamax_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "addsign_test",
+ size = "small",
+ srcs = ["addsign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "powersign_test",
+ size = "small",
+ srcs = ["powersign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
@@ -180,7 +235,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "cholesky_op_test",
- size = "small",
+ size = "medium",
srcs = ["cholesky_op_test.py"],
tags = ["optonly"],
deps = [
@@ -363,7 +418,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "eager_test",
- size = "small",
+ size = "large",
srcs = ["eager_test.py"],
disabled_backends = [
# TODO(b/78199195) Support XLA CPU devices in eager runtime
@@ -610,6 +665,27 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "qr_op_test",
+ size = "medium",
+ srcs = ["qr_op_test.py"],
+ disabled_backends = [
+ # Test is very slow on CPU.
+ "cpu",
+ "cpu_ondemand",
+ ],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_xla_py_test(
name = "random_ops_test",
size = "small",
srcs = ["random_ops_test.py"],
@@ -924,7 +1000,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "sort_ops_test",
- size = "small",
+ size = "medium",
srcs = ["sort_ops_test.py"],
# Times out in fastbuild mode.
tags = ["optonly"],
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
new file mode 100644
index 0000000000..dc1625793a
--- /dev/null
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -0,0 +1,165 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdagradDA optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad_da
+
+
+class AdagradDAOptimizerTest(xla_test.XLATestCase):
+
+ def testAdagradDAWithoutRegularizationBasic1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ # Let g be the gradient accumulator, gg the gradient-squared accumulator,
+ # T the global step, lr the learning rate, and k the initial value of the
+ # gradient-squared accumulator.
+ # w = sign(-g) * lr * max(|g| - l1*T, 0) / (l2*T*lr + sqrt(k + gg))
+ # For var0[0]: sign(-0.1) * 3.0 * (0.1 - 0) / (0 + sqrt(0.1 + 0.1*0.1))
+ # = -0.904534; similarly for the other entries.
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAWithoutRegularizationBasic2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAWithL1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.895489, -1.59555]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.085339, -0.17989]), var1.eval())
+
+ def testAdagradDAWithL1_L2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.046907, -0.093659]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.004275, -0.009023]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
new file mode 100644
index 0000000000..c4fdbc5974
--- /dev/null
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdaMax optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import adamax
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adamax_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = np.maximum(beta2 * v, np.abs(g_t))
+ param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon))
+ return param_t, m_t, v_t
+
+
+class AdaMaxOptimizerTest(xla_test.XLATestCase):
+
+ def testBasic(self):
+ for i, dtype in enumerate(self.float_types):
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adamax.AdaMaxOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ update.run()
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
+ self.assertEqual("var0_%d/AdaMax:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testTensorLearningRate(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
new file mode 100644
index 0000000000..9ec5a964cb
--- /dev/null
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AddSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import addsign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ def linear_decay(step):
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def addsign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ alpha=1.0,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t)
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class AddSignTest(xla_test.XLATestCase):
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ alpha=1.0,
+ beta=0.9):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = addsign.AddSignOptimizer(
+ learning_rate=learning_rate,
+ alpha=alpha,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of AddSign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ var0_np, m0 = addsign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = addsign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index 98d41ba7ed..f9db103f6d 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -33,12 +33,9 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
-
DATA_FORMATS = (
("_data_format_NHWC", "NHWC"),
("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 3524666499..6ead15da13 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -403,7 +403,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
def testSliceInDefun(self):
with self.test_scope():
- @function.defun(compiled=True)
+ @function.defun
def f(x, y):
return x[0::2, y:, ...]
@@ -418,6 +418,22 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
self.assertAllEqual((2, 3, 4), dz.shape.as_list())
+ def testNestedDefun(self):
+ self.skipTest('Nested defuns do not work on TPU at the moment')
+ with self.test_scope():
+
+ @function.defun
+ def times_two(x):
+ return 2 * x
+
+ @function.defun
+ def two_x_plus_1(x):
+ return times_two(x) + 1
+
+ x = constant_op.constant([2, 3, 4])
+ y = two_x_plus_1(x)
+ self.assertAllEqual([5, 7, 9], y.numpy())
+
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
@@ -470,6 +486,36 @@ class ExcessivePaddingTest(xla_test.XLATestCase):
self.assertAllEqual(100 * [[36.0]], reduced)
+def multiple_tpus():
+ devices = context.context().devices()
+ return len([d for d in devices if 'device:TPU:' in d]) > 1
+
+
+class MultiDeviceTest(xla_test.XLATestCase):
+ """Test running TPU computation on more than one core."""
+
+ def testBasic(self):
+ if not multiple_tpus():
+ self.skipTest('MultiDeviceTest requires multiple TPU devices.')
+
+ # Compute 10 on TPU core 0
+ with ops.device('device:TPU:0'):
+ two = constant_op.constant(2)
+ five = constant_op.constant(5)
+ ten = two * five
+ self.assertAllEqual(10, ten)
+
+ # Compute 6 on TPU core 1
+ with ops.device('device:TPU:1'):
+ two = constant_op.constant(2)
+ three = constant_op.constant(3)
+ six = two * three
+ self.assertAllEqual(6, six)
+
+ # Copy 10 and 6 to CPU and sum them
+ self.assertAllEqual(16, ten + six)
+
+
if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
new file mode 100644
index 0000000000..5fa7706d72
--- /dev/null
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for PowerSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import powersign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ def linear_decay(step):
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def powersign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ base=math.e,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t))
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class PowerSignTest(xla_test.XLATestCase):
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ base=math.e,
+ beta=0.9):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = powersign.PowerSignOptimizer(
+ learning_rate=learning_rate,
+ base=base,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of powersign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ var0_np, m0 = powersign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = powersign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.1, base=10.0, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
new file mode 100644
index 0000000000..93752a21db
--- /dev/null
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -0,0 +1,112 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def AdjustedNorm(self, x):
+ """Computes the norm of matrices in 'x', adjusted for dimension and type."""
+ norm = np.linalg.norm(x, axis=(-2, -1))
+ return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps)
+
+ def CompareOrthogonal(self, x, y, rank):
+ # We only compare the first 'rank' orthogonal vectors since the
+ # remainder form an arbitrary orthonormal basis for the
+ # (row- or column-) null space, whose exact value depends on
+ # implementation details. Notice that since we check elsewhere that the
+ # computed Q factors are unitary, we implicitly test that the trailing
+ # vectors of x and y span the same space.
+ x = x[..., 0:rank]
+ y = y[..., 0:rank]
+ # Q is only unique up to sign (complex phase factor for complex matrices),
+ # so we normalize the sign first.
+ sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
+ phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
+ x *= phases
+ self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0))
+
+ def CheckApproximation(self, a, q, r):
+ # Tests that a ~= q*r.
+ precision = self.AdjustedNorm(a - np.matmul(q, r))
+ self.assertTrue(np.all(precision < 5.0))
+
+ def CheckUnitary(self, x):
+ # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
+ xx = math_ops.matmul(x, x, adjoint_a=True)
+ identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
+ precision = self.AdjustedNorm(xx.eval() - identity.eval())
+ self.assertTrue(np.all(precision < 5.0))
+
+ def _test(self, dtype, shape, full_matrices):
+ np.random.seed(1)
+ x_np = np.random.uniform(
+ low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+ with self.test_session() as sess:
+ x_tf = array_ops.placeholder(dtype)
+ with self.test_scope():
+ q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
+ q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
+
+ q_dims = q_tf_val.shape
+ np_q = np.ndarray(q_dims, dtype)
+ np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
+ new_first_dim = np_q_reshape.shape[0]
+
+ x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
+ for i in range(new_first_dim):
+ if full_matrices:
+ np_q_reshape[i, :, :], _ = np.linalg.qr(
+ x_reshape[i, :, :], mode="complete")
+ else:
+ np_q_reshape[i, :, :], _ = np.linalg.qr(
+ x_reshape[i, :, :], mode="reduced")
+ np_q = np.reshape(np_q_reshape, q_dims)
+ self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
+ self.CheckApproximation(x_np, q_tf_val, r_tf_val)
+ self.CheckUnitary(q_tf_val)
+
+ SIZES = [1, 2, 5, 10, 32, 100, 300]
+ DTYPES = [np.float32]
+ PARAMS = itertools.product(SIZES, SIZES, DTYPES)
+
+ @parameterized.parameters(*PARAMS)
+ def testQR(self, rows, cols, dtype):
+ # TODO(b/111317468): implement full_matrices=False, test other types.
+ for full_matrices in [True]:
+ # Only tests the (3, 2) case for small numbers of rows/columns.
+ for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
+ self._test(dtype, batch_dims + (rows, cols), full_matrices)
+
+
+if __name__ == "__main__":
+ test.main()
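
Q in a QR factorization is unique only up to a per-column sign (a unit-modulus phase for complex inputs), which is why CompareOrthogonal normalizes phases before comparing against NumPy. A self-contained NumPy sketch of that alignment, with illustrative names only:

    import numpy as np

    np.random.seed(0)
    a = np.random.uniform(-1.0, 1.0, size=(5, 3)).astype(np.float32)
    q_ref, r_ref = np.linalg.qr(a, mode="complete")

    # Simulate the harmless ambiguity: flip one column's sign, as a
    # different-but-valid QR implementation might.
    q_other = q_ref.copy()
    q_other[:, 0] *= -1.0

    # Align signs column-by-column, exactly as CompareOrthogonal's
    # sum-of-ratios trick does.
    ratios = np.sum(q_ref[:, :3] / q_other[:, :3], axis=-2, keepdims=True)
    q_other[:, :3] *= ratios / np.abs(ratios)

    assert np.allclose(q_other[:, :3], q_ref[:, :3])
    assert np.allclose(np.matmul(q_ref, r_ref), a, atol=1e-5)  # a ~= q * r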
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index b880b2a3fe..14c5e7a975 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -140,10 +140,10 @@ class RandomOpsTest(xla_test.XLATestCase):
def testShuffle1d(self):
with self.test_session() as sess:
with self.test_scope():
- x = math_ops.range(20)
+ x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
- expected = range(20)
+ expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
self.assertAllEqual(set(result), set(expected))
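
The jump from 20 to 1 << 16 elements pushes this test onto the sort-based R1 shuffle path added to random_ops.cc later in this patch, and is large enough to require more than one key-sort round. A quick check of the round count, assuming the kernel's constants:

    import math

    n = 1 << 16
    kuint32max = 2**32 - 1
    exponent = 3  # `Exponent` in RandomShuffleOp
    rounds = math.ceil(exponent * math.log(n) / math.log(kuint32max))
    print(rounds)  # 2 key-sort rounds for 65536 elements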
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index 9489fded32..ff8bbac911 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -30,31 +30,102 @@ from tensorflow.python.training import rmsprop
class RmspropTest(xla_test.XLATestCase):
+ def _rmsprop_update_numpy(self,
+ var,
+ g,
+ mg,
+ rms,
+ mom,
+ lr,
+ decay=0.9,
+ momentum=0.0,
+ epsilon=1e-10,
+ centered=False):
+ rms_t = rms * decay + (1 - decay) * g * g
+ denom_t = rms_t + epsilon
+ if centered:
+ mg_t = mg * decay + (1 - decay) * g
+ denom_t -= mg_t * mg_t
+ else:
+ mg_t = mg
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- rms_opt = rmsprop.RMSPropOptimizer(3.0)
- rms_update = rms_opt.apply_gradients(
- zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
-
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
-
- # Run 3 steps of RMSProp
- for _ in range(3):
- rms_update.run()
-
- # Validate updated params
- self.assertAllCloseAccordingToType(
- np.array([2.91705132e-04, 1.00029182e+00]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([2.89990854, 3.89990854]), var1.eval())
+ for centered in [False, True]:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+ mg0_np = np.array([0.0, 0.0], dtype=dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype)
+ rms0_np = np.array([1.0, 1.0], dtype=dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ learning_rate = 3.0
+ rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
+ rms_update = rms_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = rms_opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = rms_opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+          rms0 = rms_opt.get_slot(var0, "rms")
+          self.assertIsNotNone(rms0)
+          rms1 = rms_opt.get_slot(var1, "rms")
+          self.assertIsNotNone(rms1)
+          mom0 = rms_opt.get_slot(var0, "momentum")
+          self.assertIsNotNone(mom0)
+          mom1 = rms_opt.get_slot(var1, "momentum")
+          self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of RMSProp
+ for _ in range(3):
+ rms_update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np,
+ grads0_np,
+ mg0_np,
+ rms0_np,
+ mom0_np,
+ learning_rate,
+ centered=centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np,
+ grads1_np,
+ mg1_np,
+ rms1_np,
+ mom1_np,
+ learning_rate,
+ centered=centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
if __name__ == "__main__":
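
To make the helper concrete: a single update with the model above, showing how the centered variant shrinks the denominator by the squared running gradient mean and therefore takes a slightly larger step. A minimal sketch with illustrative values:

    import numpy as np

    var = np.array([1.0, 2.0])
    g = np.array([0.1, 0.1])
    rms = np.ones(2)
    decay, lr, epsilon = 0.9, 3.0, 1e-10

    rms_t = rms * decay + (1 - decay) * g * g      # [0.901, 0.901]
    plain = var - lr * g / np.sqrt(rms_t + epsilon)

    mg_t = (1 - decay) * g                         # centered: mg starts at 0
    cent = var - lr * g / np.sqrt(rms_t - mg_t * mg_t + epsilon)

    print(plain, cent)  # the centered step is marginally larger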
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 40e32f2e75..ff002d15b0 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -139,12 +139,14 @@ cc_library(
"xla_op_registry.cc",
"xla_resource.cc",
"xla_cpu_backend.cc",
+ "legacy_flags/backend_registration_flags.cc",
] + if_cuda_is_configured([
"xla_gpu_backend.cc",
]),
hdrs = [
"const_analysis.h",
"graph_compiler.h",
+ "legacy_flags/backend_registration_flags.h",
"xla_compilation_device.h",
"xla_compiler.h",
"xla_context.h",
@@ -162,7 +164,7 @@ cc_library(
":sharding_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -175,9 +177,11 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@@ -202,7 +206,7 @@ cc_library(
],
visibility = [":friends"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu_internal",
@@ -285,6 +289,7 @@ tf_cc_test(
deps = [
":tf2xla",
":tf2xla_proto",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
@@ -327,7 +332,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
@@ -364,6 +369,7 @@ tf_cc_test(
],
deps = [
":common",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:framework",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 4900af6df1..e1cea03865 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -161,9 +161,8 @@ Status GraphCompiler::Compile() {
outputs.resize(n->num_outputs());
for (int o = 0; o < n->num_outputs(); ++o) {
outputs[o] = op_context.release_output(o);
- if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) {
+ if (outputs[o].tensor == nullptr) {
return errors::Internal("Missing xla_context ", o, "-th output from ",
- (*op_context.is_output_dead() ? "(dead)" : ""),
SummarizeNode(*n));
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index a8eb7d942d..5a335aa43c 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -58,6 +58,7 @@ tf_kernel_library(
"pack_op.cc",
"pad_op.cc",
"pooling_ops.cc",
+ "qr_op.cc",
"quantize_and_dequantize_op.cc",
"random_ops.cc",
"reduce_window_op.cc",
@@ -107,6 +108,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:cholesky",
+ "//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
@@ -114,6 +116,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -159,7 +162,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -175,7 +178,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -210,6 +213,7 @@ tf_kernel_library(
":index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index ee2c920453..ba3b1c9dab 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 2c76bcee25..81f42e504e 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index a020ebc729..22a45b2a11 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
- &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
- &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index a81f5fddf6..12d9cb9bac 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
new file mode 100644
index 0000000000..de9068a640
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/qr.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+class QROp : public XlaOpKernel {
+ public:
+ explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ bool full_matrices;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices));
+ OP_REQUIRES(
+ ctx, full_matrices,
+ errors::Unimplemented("full_matrices=False case of QR decomposition is "
+ "not implemented in TF/XLA"));
+ }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto result = QRDecomposition(ctx->Input(0));
+ if (!result.ok()) {
+ ctx->SetStatus(result.status());
+ return;
+ }
+ ctx->SetOutput(0, result.ValueOrDie().q);
+ ctx->SetOutput(1, result.ValueOrDie().r);
+ }
+};
+
+REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+
+} // namespace
+} // namespace tensorflow
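
A hedged usage sketch of the path this kernel serves: linalg_ops.qr lowers to the Qr op, which REGISTER_XLA_OP binds to QROp under XLA compilation (the test file earlier in this patch drives it through xla_test's test_scope()):

    import numpy as np
    from tensorflow.python.ops import linalg_ops

    x = np.random.uniform(-1.0, 1.0, (4, 3)).astype(np.float32)
    q, r = linalg_ops.qr(x, full_matrices=True)
    # full_matrices=False on the XLA path raises Unimplemented, per QROp.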
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 9a0a7f9b90..607cad798a 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -74,56 +74,121 @@ class RandomShuffleOp : public XlaOpKernel {
for (tensorflow::TensorShapeDim dimension : input_shape) {
num_elements *= dimension.size;
}
+
if (num_elements <= 1 || n <= 1) {
// No shuffling is required, so copy input directly to output
ctx->SetOutput(0, input);
- } else {
- // Generate the random swaps for the indices.
- auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
- auto swaps =
- xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
- xla::ConstantR0<int32>(builder, n), swaps_shape);
-
- // Generate range(n) as the initial value for the indices to be swapped.
- xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
-
- // Swap the indices at i and swaps[i].
- auto swap_body_fn = [&](xla::XlaOp i,
- gtl::ArraySlice<xla::XlaOp> loop_vars,
- xla::XlaBuilder* builder)
- -> xla::StatusOr<std::vector<xla::XlaOp>> {
- auto swaps = loop_vars[0];
- auto indices = loop_vars[1];
- i = xla::Reshape(i, {1});
- // temp = indices[i]
- auto temp = xla::DynamicSlice(indices, i, {1});
- // swap_index = swaps[i]
- auto swap_index = xla::DynamicSlice(swaps, i, {1});
- // swap_value = indices[swaps[i]]
- auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
- // indices[i] = indices[swaps[i]]
- indices = xla::DynamicUpdateSlice(indices, swap_value, i);
- // indices[swaps[i]] = temp
- indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
- return std::vector<xla::XlaOp>{swaps, indices};
- };
- // for i in range(n):
- auto swap_loop_result =
- XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
- "indices_swap_loop", builder)
- .ValueOrDie();
- auto swapped_indices = swap_loop_result[1];
-
- // Gather the data using the swapped indices as the shuffled order.
- auto indices_tensor_shape = TensorShape({n});
- DataType type = ctx->expected_output_dtype(0);
- xla::XlaOp gather;
- OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
- indices_tensor_shape,
- /*axis=*/0, /*indices_are_nd=*/false, type,
- DT_INT32, builder, &gather));
- ctx->SetOutput(0, gather);
+ return;
+ }
+
+ if (input_shape.dims() == 1) {
+      // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
+      // algorithm. Fisher-Yates is simple to implement and correct, but not
+      // easily parallelizable. On a sufficiently parallel architecture, it is
+      // faster to sort many times than to run Fisher-Yates once.
+
+      // Shuffle values by assigning each value a random key and sorting the
+      // keys. Keys can collide, causing detectable patterns in the shuffled
+      // output: collisions translate into more ascending sub-sequences in the
+      // shuffled output than would be expected by chance. To avoid collisions,
+      // the number of possible key values must be sufficiently large.
+
+ // How are more than 2^32 keys created? In each loop iteration, the
+ // algorithm sorts by random keys. Conceptually, the earlier iterations
+ // are sorting on the lower-order bits of larger keys that are never
+ // actually assembled.
+
+ // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
+ // the number of possible keys and n is the number of values. If d = n^2,
+ // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
+ // as n goes to infinity is zero.
+
+ // This implementation ensures that the key-space is greater than or equal
+ // to the cube of the number of values. The risk of collisions can be
+ // further reduced by increasing Exponent at the expense of
+ // performance.
+
+ // For Exponent = 2, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
+ // about 1/2.
+
+ // For Exponent = 3, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
+ // about 1/3255.
+
+ // For Exponent = 4, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
+ // about 1/132622.
+ constexpr int Exponent = 3;
+ const int rounds = static_cast<int>(
+ std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));
+
+ const xla::Shape key_shape =
+ xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
+ xla::XlaOp zero = xla::ConstantR0(builder, 0U);
+
+ // Unfortunately, xla::RngUniform gives values in the half open interval
+ // rather than the closed interval, so instead of 2^32 possible keys there
+ // are only 2^32 - 1 (kuint32max).
+ xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);
+
+ xla::XlaOp curr = input;
+ for (int i = 0; i < rounds; ++i) {
+ xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
+ xla::XlaOp sorted = xla::Sort(keys, curr);
+ curr = xla::GetTupleElement(sorted, 1);
+ }
+
+ ctx->SetOutput(0, curr);
+ return;
}
+
+ // The Fisher-Yates algorithm.
+
+ // Generate the random swaps for the indices.
+ auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
+ auto swaps =
+ xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
+ xla::ConstantR0<int32>(builder, n), swaps_shape);
+
+ // Generate range(n) as the initial value for the indices to be swapped.
+ xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
+
+ // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto swaps = loop_vars[0];
+ auto indices = loop_vars[1];
+ i = xla::Reshape(i, {1});
+ // temp = indices[i]
+ auto temp = xla::DynamicSlice(indices, i, {1});
+ // swap_index = swaps[i]
+ auto swap_index = xla::DynamicSlice(swaps, i, {1});
+ // swap_value = indices[swaps[i]]
+ auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
+ // indices[i] = indices[swaps[i]]
+ indices = xla::DynamicUpdateSlice(indices, swap_value, i);
+ // indices[swaps[i]] = temp
+ indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
+ return std::vector<xla::XlaOp>{swaps, indices};
+ };
+ // for i in range(n):
+ auto swap_loop_result =
+ XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
+ .ValueOrDie();
+ auto swapped_indices = swap_loop_result[1];
+
+ // Gather the data using the swapped indices as the shuffled order.
+ auto indices_tensor_shape = TensorShape({n});
+ DataType type = ctx->expected_output_dtype(0);
+ xla::XlaOp gather;
+ OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
+ indices_tensor_shape,
+ /*axis=*/0, /*indices_are_nd=*/false, type,
+ DT_INT32, builder, &gather));
+ ctx->SetOutput(0, gather);
}
private:
@@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal")
.TypeConstraint("dtype", DT_FLOAT),
TruncatedNormalOp);
-} // anonymous namespace
+} // namespace
} // namespace tensorflow
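
A NumPy spot-check of the collision arithmetic quoted in RandomShuffleOp's comments, plus a one-round illustration of shuffle-by-sort itself (the kernel composes xla::RngUniform, xla::Sort and xla::GetTupleElement instead; this sketch is illustrative only):

    import numpy as np

    def expected_collisions(n, d):
      # n - d + d * (1 - 1/d)**n, from the comment block in RandomShuffleOp.
      return n - d + d * (1.0 - 1.0 / d) ** n

    print(expected_collisions(1625, 1625**3))     # ~1/3255 (Exponent = 3)
    print(expected_collisions(65535, 65535**2))   # ~1/2    (Exponent = 2)

    # One sorting round: tag each value with a random key, sort by key.
    rng = np.random.RandomState(0)
    values = np.arange(8)
    keys = rng.randint(0, 2**32 - 1, size=values.shape, dtype=np.int64)
    print(values[np.argsort(keys)])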
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 46fae59ad4..be7f2bce8c 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 909783ecb3..ed1d1c6610 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index a4ba6c748a..f4b804e546 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index e0ca8dd8e2..354fec9be7 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index 037c422258..ec15b4cc7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 76924c6a01..27ab3e1bf5 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index bc3d0bf5df..25a5bcbe1d 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index ca74cf2450..242638f981 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 591e61b4c8..df91900570 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 2f650ce305..26326f18b8 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 9962f1207d..1ddcb08c8e 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index bef6161e85..98df730249 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -268,6 +268,83 @@ REGISTER_XLA_OP(
Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyProximalAdagrad);
+class ResourceApplyAdagradDA : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, squared_accum_shape;
+ xla::XlaOp var, accum, squared_accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
+ &squared_accum));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(squared_accum_shape),
+ errors::InvalidArgument(
+ "var and squared accum do not have the same shape",
+ var_shape.DebugString(), " ", squared_accum_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape l1_shape = ctx->InputShape(5);
+ TensorShape l2_shape = ctx->InputShape(6);
+ TensorShape global_step_shape = ctx->InputShape(7);
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
+ errors::InvalidArgument("global step is not a scalar: ",
+ global_step_shape.DebugString()));
+
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaBuilder* const b = ctx->builder();
+ xla::XlaOp global_step =
+ XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+
+ accum = accum + grad;
+ squared_accum = squared_accum + xla::Square(grad);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
+ xla::XlaOp l1_le_zero = -lr * accum / denominator;
+ xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
+ xla::Max(xla::Abs(accum) - global_step * l1, zero) /
+ denominator;
+
+ var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdagradDA);
+
class ResourceApplyAdam : public XlaOpKernel {
public:
explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -353,36 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
ResourceApplyAdam);
-class ResourceApplyRMSProp : public XlaOpKernel {
+class ResourceApplyAdaMax : public XlaOpKernel {
public:
- explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
- DataType type = ctx->input_type(3);
+ TensorShape var_shape, m_shape, v_shape;
+ xla::XlaOp var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
- TensorShape var_shape, ms_shape, mom_shape;
- xla::XlaOp var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
+ TensorShape beta1_power_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape beta1_shape = ctx->InputShape(5);
+ TensorShape beta2_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape(8);
- TensorShape lr_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+                errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var_shape.DebugString(), " ",
+ v_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp beta1 = ctx->Input(5);
+ xla::XlaOp beta2 = ctx->Input(6);
+ xla::XlaOp epsilon = ctx->Input(7);
+ xla::XlaOp grad = ctx->Input(8);
+
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ m = beta1 * m + (one - beta1) * grad;
+ v = xla::Max(beta2 * v, xla::Abs(grad));
+ var = var - lr / (one - beta1_power) * (m / (v + epsilon));
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdaMax);
+
+class ResourceApplyRMSProp : public XlaOpKernel {
+ public:
+ explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, ms_shape, mom_shape, mg_shape;
+ xla::XlaOp var, ms, mom, mg;
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));
+
+ TensorShape lr_shape = ctx->InputShape("lr");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
- TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape rho_shape = ctx->InputShape("rho");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
errors::InvalidArgument("rho is not a scalar: ",
rho_shape.DebugString()));
- TensorShape momentum_shape = ctx->InputShape(5);
+ TensorShape momentum_shape = ctx->InputShape("momentum");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- TensorShape epsilon_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape("epsilon");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon_shape.DebugString()));
- TensorShape grad_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape("grad");
// var should be the same shape as mom and ms.
OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
@@ -398,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::XlaOp lr = ctx->Input(3);
- xla::XlaOp rho = ctx->Input(4);
- xla::XlaOp momentum = ctx->Input(5);
- xla::XlaOp epsilon = ctx->Input(6);
- xla::XlaOp grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input("lr");
+ xla::XlaOp rho = ctx->Input("rho");
+ xla::XlaOp momentum = ctx->Input("momentum");
+ xla::XlaOp epsilon = ctx->Input("epsilon");
+ xla::XlaOp grad = ctx->Input("grad");
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -421,20 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms =
- ms + (xla::Square(grad) - ms) * (xla::ScalarLike(ms, 1.0) - rho);
- xla::XlaOp new_mom =
- mom * momentum + grad * lr * xla::Rsqrt(new_ms + epsilon);
+ xla::XlaOp one = xla::ScalarLike(ms, 1.0);
+ xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
+ xla::XlaOp denominator;
+ if (centered_) {
+ mg = grad * (one - rho) + mg * rho;
+ denominator = new_ms - xla::Square(mg) + epsilon;
+ } else {
+ denominator = new_ms + epsilon;
+ }
+ xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
xla::XlaOp new_var = var - new_mom;
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
}
+
+ protected:
+ bool centered_ = false;
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
ResourceApplyRMSProp);
+class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
+ public:
+ explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
+ : ResourceApplyRMSProp(ctx) {
+ centered_ = true;
+ }
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
+ ResourceApplyCenteredRMSProp);
+
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::XlaBuilder* b = ctx->builder();
@@ -640,5 +819,107 @@ class ResourceApplyAdadelta : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
ResourceApplyAdadelta);
+class ResourceApplySignBase : public XlaOpKernel {
+ public:
+ explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, m_shape;
+ xla::XlaOp var, m;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ TensorShape grad_shape = ctx->InputShape(6);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ CheckScalarParams(ctx);
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp alpha = ctx->Input(3);
+ xla::XlaOp sign_decay = ctx->Input(4);
+ xla::XlaOp beta = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
+ xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;
+
+ xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
+ var = var - lr * grad_scale * grad;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ }
+
+ virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
+ TensorShape lr_shape = ctx->InputShape(2);
+ TensorShape sign_decay_shape = ctx->InputShape(4);
+ TensorShape beta_shape = ctx->InputShape(5);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
+ errors::InvalidArgument("sign_decay is not a scalar: ",
+ sign_decay_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
+ errors::InvalidArgument("beta is not a scalar: ",
+ beta_shape.DebugString()));
+ }
+
+ virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
+ xla::XlaOp decay) = 0;
+
+ private:
+ DataType dtype_;
+};
+
+class ResourceApplyAddSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape alpha_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return alpha + decay;
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAddSign);
+
+class ResourceApplyPowerSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape logbase_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
+ errors::InvalidArgument("logbase is not a scalar: ",
+ logbase_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return xla::Exp(alpha * decay);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyPowerSign);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index 0e5d58ecba..f951127bb9 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index febac82873..bb27b5d56f 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 340165bac6..9413a30a6c 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
new file mode 100644
index 0000000000..661505021f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's backend registration modules.
+
+#include <mutex> // NOLINT
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static BackendRegistrationFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new BackendRegistrationFlags;
+ flags->tf_enable_prng_ops_gpu = false;
+ flag_list = new std::vector<Flag>({
+ Flag("tf_enable_prng_ops_gpu", &flags->tf_enable_prng_ops_gpu,
+ "Whether to enable PRNG ops: [RandomStandardNormal | RandomUniform "
+ "| RandomUniformInt | TruncatedNormal] on GPU."),
+ });
+ xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// backend registration modules.
+void AppendBackendRegistrationFlags(std::vector<Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the BackendRegistrationFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BackendRegistrationFlags* GetBackendRegistrationFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
new file mode 100644
index 0000000000..861c923dd5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
+
+// Legacy flags for the XLA bridge's backend registration modules.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// backend registration modules.
+void AppendBackendRegistrationFlags(std::vector<tensorflow::Flag>* append_to);
+
+// The values of flags associated with the XLA bridge's backend registration
+// module.
+typedef struct {
+  // Whether to enable the PRNG ops RandomStandardNormal, RandomUniform,
+  // RandomUniformInt and TruncatedNormal on the GPU backend.
+  // TODO(b/32333178): Remove this flag or set its default to true.
+ bool tf_enable_prng_ops_gpu;
+} BackendRegistrationFlags;
+
+// Return a pointer to the BackendRegistrationFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BackendRegistrationFlags* GetBackendRegistrationFlags();
+
+} // namespace legacy_flags
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
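
In this era of the codebase, xla::legacy_flags::ParseFlagsFromEnv reads legacy flags from the TF_XLA_FLAGS environment variable, so the new flag can be toggled without code changes. A hedged sketch (verify the variable name against parse_flags_from_env.h):

    import os

    # Must be set before the process initializes the XLA bridge.
    os.environ["TF_XLA_FLAGS"] = "--tf_enable_prng_ops_gpu=true"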
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index dfa3c0595a..becc8b84fe 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -40,7 +40,7 @@ cc_library(
":triangular_solve",
":util",
":while_loop",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -67,13 +67,35 @@ cc_library(
)
cc_library(
+ name = "qr",
+ srcs = ["qr.cc"],
+ hdrs = ["qr.h"],
+ deps = [
+ ":batch_dot",
+ ":util",
+ ":while_loop",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
+ "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "scatter",
srcs = ["scatter.cc"],
hdrs = ["scatter.h"],
deps = [
":util",
":while_loop",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -92,7 +114,7 @@ cc_library(
deps = [
":batch_dot",
":util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -111,7 +133,7 @@ xla_test(
deps = [
":triangular_solve",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -133,6 +155,7 @@ cc_library(
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -151,7 +174,7 @@ xla_test(
":batch_dot",
":util",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index f9f3a8c8cf..3c4eec081b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -84,7 +84,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
dimensions.push_back(y_shape.dimensions(y_outer_dim));
return xla::Broadcast(
xla::ConstantLiteral(builder,
- xla::Literal::Zero(x_shape.element_type())),
+ xla::LiteralUtil::Zero(x_shape.element_type())),
dimensions);
}
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index cc840de393..35b137aa2c 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
new file mode 100644
index 0000000000..9c8ac7af25
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -0,0 +1,387 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/qr.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Computes a Householder reflection of the form:
+// H = I - tau v v.T.
+// such that
+// H . ( x1 ) = ( x1 )
+//     ( x2 )     ( x2 )
+//     ( ... )    ( ... )
+//     ( xk )     ( beta )
+//     ( ... )    ( 0 )
+//     ( ... )    ( 0 )
+// Unlike the usual formulation, we allow the caller to supply 'k' rather than
+// requiring that only the relevant part of 'x' be provided, so as to maintain
+// XLA's static shape invariant. In addition, the implementation supports
+// batching.
+// Pseudo-code, without batching:
+// alpha = x[k]
+// x_copy = np.copy(x)
+// x_copy[:k+1] = 0
+// xnorm = norm2(x_copy)
+// if xnorm == 0:
+// beta = alpha
+// tau = 0
+// v = np.zeros_like(x)
+// else:
+// beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
+// tau = (beta - alpha) / beta
+// v = x / (alpha - beta)
+// v[k] = 1
+// return (v, tau, beta)
+// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
+// overflows in the norm/beta calculations. Perhaps do the same here.
+xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice<int64> batch_dims,
+ const int64 m, xla::XlaOp* v, xla::XlaOp* tau,
+ xla::XlaOp* beta) {
+ xla::XlaBuilder* const builder = x.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ const xla::PrimitiveType type = x_shape.element_type();
+
+ std::vector<int64> batch_dim_ids(batch_dims.size());
+ std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
+ const int64 minor_dim = batch_dims.size();
+
+ xla::XlaOp zero = xla::ScalarLike(x, 0.0);
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
+
+ // alpha = x[k]
+ xla::XlaOp alpha =
+ xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+
+ // Compute x[k+1:] (padded with zeros in elements 0..k)
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, m);
+ xla::XlaOp x_after_k =
+ xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type),
+ /*broadcast_dimensions=*/{minor_dim});
+
+ // sigma = np.dot(x[k+1:], x[k+1:])
+ auto sigma =
+ xla::Reduce(x_after_k * x_after_k, zero,
+ xla::CreateScalarAddComputation(type, builder), {minor_dim});
+ // mu = np.sqrt(x[k]*x[k] + sigma)
+ auto mu = xla::Sqrt(xla::Square(alpha) + sigma);
+
+ auto sigma_is_zero = xla::Eq(sigma, zero);
+
+ *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu);
+ *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims),
+ (*beta - alpha) / *beta);
+ auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims),
+ alpha - *beta);
+
+ auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type),
+ std::vector<int64>(batch_dims.size(), 1));
+
+ // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
+ // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
+ *v = e_k +
+ xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+ return Status::OK();
+}
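
For intuition, here is a minimal unbatched numpy sketch of the reflector this
routine computes, mirroring the pseudo-code and the sigma == 0 branch of the
XLA code above (the name house is hypothetical, not part of this change):

    import numpy as np

    def house(x, k):
        # Householder reflector H = I - tau * outer(v, v) such that
        # (H @ x)[k] == beta and (H @ x)[k+1:] == 0.
        x = np.asarray(x, dtype=float)
        m = x.shape[0]
        alpha = x[k]
        x_after_k = np.where(np.arange(m) > k, x, 0.0)  # zero entries 0..k
        sigma = np.dot(x_after_k, x_after_k)
        mu = np.sqrt(alpha * alpha + sigma)
        if sigma == 0.0:
            beta, tau, divisor = alpha, 0.0, 1.0
        else:
            beta = -np.sign(alpha) * mu
            tau = (beta - alpha) / beta
            divisor = alpha - beta
        e_k = (np.arange(m) == k).astype(x.dtype)
        v = e_k + x_after_k / divisor  # v[:k] == 0, v[k] == 1
        return v, tau, beta

For example, house([3.0, 4.0], 0) yields v = [1.0, 0.5], tau = 1.6,
beta = -5.0, and I - tau * outer(v, v) maps [3, 4] to [-5, 0].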
+
+// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
+// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
+// used as an inner routine of the blocked implementation.
+// Algorithm is adapted slightly so the shapes inside the loop are static, at
+// the cost of some redundant computation. Since this is used as an inner block
+// kernel, it accumulates the Householder transformations (vs, taus) rather
+// than the matrix q.
+// Equivalent Python code, without batching:
+// def qr(a):
+// m = a.shape[0]
+// n = a.shape[1]
+// vs = np.zeros([m, n])
+// taus = np.zeros([n])
+// for j in xrange(min(m, n)):
+// v, tau, beta = house(a[:, j], j)
+// # Unusually, we apply the Householder transformation to the entirety of
+// # a, wasting FLOPs to maintain the static shape invariant that XLA
+// # requires. For columns that precede j this has no effect.
+// a[:, :] -= tau * np.dot(v[:, np.newaxis],
+// np.dot(v[np.newaxis, :], a[:, :]))
+// # Form column j explicitly rather than relying on the precision of the
+// # Householder update.
+// a[j, j] = beta
+// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
+// vs[:, j] = v
+// taus[j] = tau
+// return (a, vs, taus)
+struct QRBlockResult {
+ // The factored R value
+ xla::XlaOp r;
+
+ // Representation of the Householder matrices I - tau v v.T
+ xla::XlaOp taus; // Shape: [..., n]
+ xla::XlaOp vs; // Shape: [..., m, n]
+};
+xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
+ xla::XlaBuilder* builder = a.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int num_dims = xla::ShapeUtil::Rank(a_shape);
+ if (num_dims < 2) {
+ return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
+ num_dims);
+ }
+ xla::PrimitiveType type = a_shape.element_type();
+
+ const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+
+ const int64 num_batch_dims = num_dims - 2;
+ std::vector<int64> batch_dims(num_batch_dims);
+ for (int i = 0; i < num_batch_dims; ++i) {
+ batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+ }
+
+ std::vector<int64> batch_dim_indices(num_batch_dims);
+ std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+
+ auto qr_body_fn =
+ [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto a = values[0];
+ auto vs = values[1];
+ auto taus = values[2];
+
+ // v, tau, beta = house(a[:, j], j)
+ auto x = DynamicSliceInMinorDims(a, {j}, {1});
+ xla::XlaOp v, tau, beta;
+ TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j,
+ batch_dims, m, &v, &tau, &beta));
+
+ std::vector<int64> shape = batch_dims;
+ shape.push_back(1);
+ shape.push_back(m);
+ auto v_broadcast = xla::Reshape(v, shape);
+ // a[:, :] -= tau * np.dot(v[:, np.newaxis],
+ // np.dot(v[np.newaxis, :], a[:, :]))
+ auto vva = BatchDot(v_broadcast, a);
+ vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true);
+ a = a - xla::Mul(tau, vva,
+ /*broadcast_dimensions=*/batch_dim_indices);
+
+ // It is more precise to populate column 'j' explicitly, rather than
+ // computing it implicitly by applying the Householder transformation.
+ // a[j, j] = beta
+ // a[j+1:, j] = np.zeros([m-j-1], dtype=a.dtype)
+ auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1});
+ auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type);
+ auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type),
+ std::vector<int64>(batch_dims.size(), 1));
+ auto new_x =
+ xla::Mul(x, predecessor_mask,
+ /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
+ xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+ a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
+
+ // vs[:, j] = v
+ vs = DynamicUpdateSliceInMinorDims(
+ vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
+ // taus[j] = tau
+ taus = DynamicUpdateSliceInMinorDims(
+ taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
+ return std::vector<xla::XlaOp>{a, vs, taus};
+ };
+
+ auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
+ type, ConcatVectors(batch_dims, {m, n})));
+ auto taus = xla::Zeros(
+ builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
+
+ TF_ASSIGN_OR_RETURN(auto values,
+ XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
+ {a, vs, taus}, "qr", builder));
+
+ QRBlockResult result;
+ result.r = values[0];
+ result.vs = values[1];
+ result.taus = values[2];
+ return result;
+}
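
A matching unbatched numpy sketch of the unblocked routine, reusing the
hypothetical house() sketch above (the XLA version additionally keeps all
loop shapes static):

    import numpy as np

    def qr_block(a):
        # Unblocked Householder QR: returns (r, vs, taus), where applying the
        # reflectors I - taus[j] * outer(vs[:, j], vs[:, j]) in order
        # reproduces a from r.
        a = np.array(a, dtype=float)
        m, n = a.shape
        vs = np.zeros((m, n))
        taus = np.zeros(n)
        for j in range(min(m, n)):
            v, tau, beta = house(a[:, j], j)
            # Apply H_j to all of a; columns before j are unchanged because
            # v[:j] == 0 and those columns are already zero from row j down.
            a -= tau * np.outer(v, v @ a)
            a[j, j] = beta          # form column j explicitly, for precision
            a[j + 1:, j] = 0.0
            vs[:, j] = v
            taus[j] = tau
        return a, vs, taus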
+
+// Computes W and Y such that I + W Y^T is equivalent to the sequence of
+// Householder transformations given by vs and taus (each -tau factor is
+// folded into W).
+// Golub and Van Loan, "Matrix Computations", algorithm 5.1.2.
+// Y = np.zeros([m, n])
+// W = np.zeros([m, n])
+// Y[:, 0] = vs[:, 0]
+// W[:, 0] = -taus[0] * vs[:, 0]
+// for j in xrange(1, n):
+// v = vs[:, j]
+// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
+// W[:, j] = z
+// Y[:, j] = v
+// return W
+// There is no need to return Y since at termination of the loop it is equal to
+// vs.
+xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
+ xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
+ xla::XlaOp taus, int64 m, int64 n) {
+ std::vector<int64> batch_dim_indices(batch_dims.size());
+ std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+ int64 n_index = batch_dims.size() + 1;
+
+ auto body_fn =
+ [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto w = values[0];
+ auto y = values[1];
+ const auto vs = values[2];
+ const auto taus = values[3];
+
+ // Want j values in range [1, ... n).
+ j = j + xla::ConstantR0<int32>(builder, 1);
+ // vs has shape [..., m, 1]
+ auto v = DynamicSliceInMinorDims(vs, {j}, {1});
+ // beta has shape [..., 1]
+ auto beta = DynamicSliceInMinorDims(taus, {j}, {1});
+
+ // yv has shape [..., n, 1]
+ auto yv = BatchDot(y, v, /*transpose_x=*/true);
+ // wyv has shape [..., m, 1]
+ auto wyv = BatchDot(w, yv);
+
+ auto z = xla::Mul(
+ -beta, v + wyv,
+ /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+
+ w = DynamicUpdateSliceInMinorDims(w, z, {j});
+ y = DynamicUpdateSliceInMinorDims(y, v, {j});
+
+ return std::vector<xla::XlaOp>{w, y, vs, taus};
+ };
+
+ xla::XlaBuilder* builder = vs.builder();
+ auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
+ type, ConcatVectors(batch_dims, {m, n})));
+ auto y = w;
+ auto v = SliceInMinorDims(vs, {0}, {1});
+ auto beta = SliceInMinorDims(taus, {0}, {1});
+ y = UpdateSliceInMinorDims(y, v, {0});
+ auto bv = xla::Mul(
+ -beta, v,
+ /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+ w = UpdateSliceInMinorDims(w, bv, {0});
+
+ TF_ASSIGN_OR_RETURN(
+ auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
+ "wy", builder));
+ return values[0];
+}
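
The same accumulation in unbatched numpy, showing the invariant
H_0 H_1 ... H_{n-1} == I + W @ Y.T with Y == vs on exit (sketch only,
hypothetical name compute_wy):

    import numpy as np

    def compute_wy(vs, taus):
        # Builds W column by column; each -tau is folded into W, so the
        # running product of reflectors equals I + w @ vs.T.
        m, n = vs.shape
        w = np.zeros((m, n))
        w[:, 0] = -taus[0] * vs[:, 0]
        for j in range(1, n):
            v = vs[:, j]
            w[:, j] = -taus[j] * (v + w[:, :j] @ (vs[:, :j].T @ v))
        return w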
+
+} // namespace
+
+// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and Van Loan.
+// def qr_blocked(a, block_size):
+// m = a.shape[0]
+// n = a.shape[1]
+// q = np.eye(m)
+// for i in xrange(0, min(m, n), block_size):
+// k = min(block_size, min(m, n) - i)
+// (a, vs, taus) = qr(a[i:, i:i+k])
+// y = vs
+// w = ComputeWYRepresentation(vs, taus, m-i, k)
+// a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
+// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
+// return (q, a)
+// TODO(phawkins): consider using UT transformations (in the form I - V U V')
+// rather than WY transformations.
+xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
+ int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int num_dims = xla::ShapeUtil::Rank(a_shape);
+ if (num_dims < 2) {
+ return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
+ num_dims);
+ }
+ xla::PrimitiveType type = a_shape.element_type();
+
+ const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ const int64 p = std::min(m, n);
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to QR must be >= 1; got ", block_size);
+ }
+
+ const int64 num_batch_dims = num_dims - 2;
+ std::vector<int64> batch_dims(num_batch_dims);
+ for (int i = 0; i < num_batch_dims; ++i) {
+ batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+ }
+
+ auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims);
+ for (int64 i = 0; i < p; i += block_size) {
+ int64 k = std::min(block_size, p - i);
+
+ auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
+ TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block));
+
+ a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
+
+ // Compute the I-WY block representation of a product of Householder
+ // matrices.
+ TF_ASSIGN_OR_RETURN(auto w,
+ ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+ qr_block.taus, m - i, k));
+ auto y = qr_block.vs;
+
+ // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
+ auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
+ auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true);
+ a_update = BatchDot(y, a_update);
+ a_panel = a_panel + a_update;
+ a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
+
+ // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
+ auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
+ auto q_update = BatchDot(q_panel, w);
+ q_update =
+ BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true);
+ q_panel = q_panel + q_update;
+ q = UpdateSliceInMinorDims(q, q_panel, {0, i});
+ }
+ QRDecompositionResult result;
+ result.q = q;
+ result.r = a;
+ return result;
+}
+
+} // namespace tensorflow
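
Putting the pieces together, an unbatched numpy sketch of the blocked driver,
using the hypothetical qr_block and compute_wy sketches above:

    import numpy as np

    def qr_blocked(a, block_size=2):
        # Blocked Householder QR: returns (q, r) with q @ r == a,
        # q orthogonal and r upper triangular.
        a = np.array(a, dtype=float)
        m, n = a.shape
        p = min(m, n)
        q = np.eye(m)
        for i in range(0, p, block_size):
            k = min(block_size, p - i)
            r_block, vs, taus = qr_block(a[i:, i:i + k])
            a[i:, i:i + k] = r_block
            w = compute_wy(vs, taus)   # panel Q = I + w @ vs.T
            # Trailing panel: apply Q.T = I + vs @ w.T from the left.
            a[i:, i + k:] += vs @ (w.T @ a[i:, i + k:])
            # Accumulate Q on the right: q[:, i:] = q[:, i:] @ (I + w @ vs.T).
            q[:, i:] += (q[:, i:] @ w) @ vs.T
        return q, a

For example, q, r = qr_blocked([[3., 1.], [4., 1.]]) gives an orthogonal q
and upper-triangular r with np.allclose(q @ r, [[3., 1.], [4., 1.]]).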
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
new file mode 100644
index 0000000000..3aa6a9b075
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace tensorflow {
+
+// Computes the QR decompositions of a batch of matrices. That is,
+// given a (batched) matrix a, computes an orthonormal matrix Q and an
+// upper-triangular matrix R such that a = QR.
+// `a` must be a (batched) matrix of size [..., m, n].
+// The algorithm implements a blocked QR decomposition; `block_size` is
+// the block size to use.
+// TODO(phawkins): handle the complex case.
+struct QRDecompositionResult {
+ xla::XlaOp q;
+ xla::XlaOp r;
+};
+
+xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
+ int64 block_size = 128);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 85e3d3ab85..6a5be1c2be 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -114,7 +114,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
auto buffer = loop_vars[2];
auto zero_index = xla::ConstantLiteral(
- body_builder, xla::Literal::Zero(indices_shape.element_type()));
+ body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
xla::XlaOp index;
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 588afaac65..ce0f28db8f 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index d5ffc1498e..f1bff6037b 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index fdc8bfca49..a6f5d346cb 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -28,6 +29,13 @@ limitations under the License.
namespace tensorflow {
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder,
+ xla::LiteralUtil::Zero(shape.element_type())),
+ xla::AsInt64Slice(shape.dimensions()));
+}
+
xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
double value) {
switch (type) {
@@ -56,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::Literal::CreateR0<uint8>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
break;
case xla::U32:
- literal = std::move(*xla::Literal::CreateR0<uint32>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
break;
case xla::U64:
- literal = std::move(*xla::Literal::CreateR0<uint64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
break;
case xla::S8:
- literal = std::move(*xla::Literal::CreateR0<int8>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
break;
case xla::S32:
- literal = std::move(*xla::Literal::CreateR0<int32>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
break;
case xla::S64:
- literal = std::move(*xla::Literal::CreateR0<int64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
break;
case xla::F32:
- literal = std::move(*xla::Literal::CreateR0<float>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
break;
case xla::F64:
- literal = std::move(*xla::Literal::CreateR0<double>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
break;
case xla::C64:
- literal = std::move(*xla::Literal::CreateR0<complex64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -89,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal = std::move(
- *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
break;
case xla::F16:
- literal = std::move(
- *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
+ literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
+ static_cast<xla::half>(value)));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 7d0f2222a9..442fe92c34 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 7cc88f34d2..574e70ddee 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -100,8 +100,9 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(xla::Add(
- iteration, xla::ConstantLiteral(
- body_builder, xla::Literal::One(num_iterations_type))));
+ iteration,
+ xla::ConstantLiteral(body_builder,
+ xla::LiteralUtil::One(num_iterations_type))));
values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
@@ -113,8 +114,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
- values.push_back(
- xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type)));
+ values.push_back(xla::ConstantLiteral(
+ builder, xla::LiteralUtil::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index b43405a1a4..2fb66913ad 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index ab7e861f33..0610a57029 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -18,7 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index f3d6787daa..a3404c2b3d 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
{
std::vector<int64> int64_values = {1, 2, 3};
std::unique_ptr<xla::Literal> int64_values_literal =
- xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values));
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
std::unique_ptr<xla::Literal> int32_values_literal =
- xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values));
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
.ok());
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 84c133ffab..f0b30dcf4e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -73,8 +74,8 @@ TEST(ConvertGraphDefToXla, Sum) {
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
// Set up arguments.
- auto x_literal = xla::Literal::CreateR0<int32>(10);
- auto y_literal = xla::Literal::CreateR0<int32>(32);
+ auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
+ auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
auto x_global_or = client->TransferToServer(*x_literal);
auto y_global_or = client->TransferToServer(*y_literal);
TF_EXPECT_OK(x_global_or.status());
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 07af8ef54b..6f76816a86 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -222,9 +222,9 @@ TEST_F(XlaCompilerTest, Simple) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -306,7 +306,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -317,9 +317,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -341,7 +341,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -351,11 +351,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
+ std::unique_ptr<xla::Literal> expected0 =
+ xla::LiteralUtil::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
}
}
@@ -569,11 +570,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> input_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> input_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::Literal> input =
- xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*input).ConsumeValueOrDie();
@@ -583,17 +584,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
+ std::unique_ptr<xla::Literal> output_read =
+ xla::LiteralUtil::CreateR0<int32>(42);
std::unique_ptr<xla::Literal> output_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> output_grad1 =
- xla::Literal::CreateR1<int32>({0, 1});
+ xla::LiteralUtil::CreateR1<int32>({0, 1});
std::unique_ptr<xla::Literal> output_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
+ xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -796,9 +798,9 @@ TEST_F(XlaCompilerTest, Variables) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -812,11 +814,11 @@ TEST_F(XlaCompilerTest, Variables) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({5, 144});
+ xla::LiteralUtil::CreateR1<int32>({5, 144});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -884,9 +886,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
+ xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -900,11 +902,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
+ xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -953,9 +955,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({4, 55, 1, -3});
+ xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -969,11 +971,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({27, 67, 35, 402});
+ xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index fd39a58ce6..0dea366476 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index ead229aacc..23d04d43b3 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -31,6 +31,10 @@ bool CpuOpFilter(KernelDef* kdef) {
DT_FLOAT);
return true;
}
+ // TODO(b/26783907): The CPU backend currently does not implement sort.
+ if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
+ return false;
+ }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
index 62168b6483..dc98d4fda6 100644
--- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
@@ -22,8 +23,16 @@ namespace tensorflow {
bool GpuOpFilter(KernelDef* kdef) {
// TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
// slow code.
- if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
- kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
+ legacy_flags::BackendRegistrationFlags* flags =
+ legacy_flags::GetBackendRegistrationFlags();
+ VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu;
+ if (!flags->tf_enable_prng_ops_gpu &&
+ (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
+ kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) {
+ return false;
+ }
+ // TODO(b/26783907): The GPU backend currently does not implement sort.
+ if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
return false;
}
if (kdef->op() == "Const") {
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index edbc5e95a8..4d1b3b1a13 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -94,13 +94,13 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return xla::ConstantLiteral(b, xla::Literal::Zero(type));
+ return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return xla::ConstantLiteral(b, xla::Literal::One(type));
+ return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 359cb4c467..e8eafb3819 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -66,10 +66,18 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
+const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+ return GetComputationFromTensor(GetInputTensorByName(name));
+}
+
TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
+TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+ return GetInputTensorByName(name).shape();
+}
+
DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
@@ -332,10 +340,11 @@ Status XlaOpKernelContext::ConstantInputList(
return Status::OK();
}
-Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
- TensorShape* shape,
- xla::XlaOp* value) {
- const Tensor& tensor = context_->input(index);
+namespace {
+
+Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, TensorShape* shape,
+ xla::XlaOp* value) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
@@ -353,7 +362,7 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
*shape = variable->shape();
}
- XlaContext& xla_context = XlaContext::Get(context_);
+ XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(
TensorShape representation_shape,
xla_context.RepresentationShape(variable->shape(), variable->type()));
@@ -365,6 +374,22 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
return Status::OK();
}
+} // namespace
+
+Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(context_->input(index), type, context_, shape,
+ value);
+}
+
+Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
+ shape, value);
+}
+
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
@@ -455,17 +480,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
return Status::OK();
}
-Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
- xla::XlaOp handle) {
- TF_RET_CHECK(handle.valid());
+namespace {
- const XlaExpression* expression =
- CastExpressionFromTensor(context_->input(input_index));
+Status AssignVariableTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, xla::XlaOp handle,
+ xla::XlaBuilder* builder) {
+ const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- auto shape_or_status = builder()->GetShape(handle);
+ auto shape_or_status = builder->GetShape(handle);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -475,7 +500,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
- XlaContext& xla_context = XlaContext::Get(context_);
+ XlaContext& xla_context = XlaContext::Get(ctx);
TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
@@ -484,6 +509,22 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
return variable->SetValue(handle);
}
+} // namespace
+
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(context_->input(input_index), type, context_,
+ handle, builder());
+}
+
+Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(GetInputTensorByName(name), type, context_,
+ handle, builder());
+}
+
XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
@@ -523,6 +564,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}
+const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+ const Tensor* tensor;
+ CHECK(context_->input(name, &tensor).ok());
+ return *tensor;
+}
+
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
void XlaOpKernel::Compute(OpKernelContext* context) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 2bde2c983d..6203cffd80 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -67,21 +67,26 @@ class XlaOpKernelContext {
// Returns the number of inputs to the operator.
int num_inputs() const { return context_->num_inputs(); }
- // Returns the type of input 'index'.
+ // Returns the type of input `index`.
DataType input_type(int index) const;
- // Returns the type of input 'index' as an xla::PrimitiveType. If the type
+ // Returns the type of input `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType input_xla_type(int index);
- // Returns the shape of input 'index'.
+ // Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns input 'index' as a XlaOp. Unlike
+ // Returns the shape of input `name`.
+ TensorShape InputShape(StringPiece name);
+
+ // Returns input `index` as an XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
+ // Returns input `name` as an XlaOp.
+ const xla::XlaOp& Input(StringPiece name);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@@ -96,13 +101,13 @@ class XlaOpKernelContext {
// Helper methods for constant inputs.
- // Evaluates input 'index' and stores it in '*constant_literal'. If the
+ // Evaluates input `index` and stores it in `*constant_literal`. If the
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
- // Evaluates input 'index', reshapes it to 'new_shape' if new_shape !=
- // InputShape(index), and stores it in '*constant_literal'. If the input
+ // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
+ // InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
@@ -137,17 +142,17 @@ class XlaOpKernelContext {
return context_->expected_output_dtype(index);
}
- // Sets output 'index' to the XlaOp 'handle'.
+ // Sets output `index` to the XlaOp `handle`.
// All outputs should be set using SetOutput and SetConstantOutput, not
// via the underlying OpKernelContext.
void SetOutput(int index, const xla::XlaOp& handle);
- // Sets output 'index' to compile-time constant 'host_tensor', where
- // 'host_tensor' is a tensor in host memory. It is preferable to use
+ // Sets output `index` to compile-time constant `host_tensor`, where
+ // `host_tensor` is a tensor in host memory. It is preferable to use
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);
- // Sets output 'index' to an invalid value.
+ // Sets output `index` to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
@@ -157,10 +162,10 @@ class XlaOpKernelContext {
// Variables
- // Sets '*resource' to the resource associated with input `index`.
+ // Sets `*resource` to the resource associated with input `index`.
Status GetResourceInput(int index, XlaResource** resource);
- // Sets output 'index' to be a reference to resource 'resource'.
+ // Sets output `index` to be a reference to resource `resource`.
void SetResourceOutput(int index, XlaResource* resource);
// Sets `*type` and `*shape` to the current type and shape of a variable's
@@ -169,17 +174,23 @@ class XlaOpKernelContext {
TensorShape* shape) const;
// Reads the current value of the resource variable referred to by input
- // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+ // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if
// its type does not match `type`.
Status ReadVariableInput(int index, DataType type, TensorShape* shape,
xla::XlaOp* value);
+ // Reads the current value of the resource variable referred to by input
+ // `name`.
+ Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
+ xla::XlaOp* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
// variable has been initialized with a different type or with a
// different shape.
Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
+ // Assigns the value `handle` to the variable referenced by input `name`.
+ Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@@ -227,6 +238,9 @@ class XlaOpKernelContext {
const xla::XlaComputation* GetOrCreateMul(const DataType type);
private:
+ // Returns the tensor of input `name`.
+ const Tensor& GetInputTensorByName(StringPiece name);
+
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 03e542855b..f1c383fd9e 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -254,6 +254,7 @@ tf_cc_test(
":types",
":util",
":xla_data_proto",
+ "//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
)
@@ -281,9 +282,9 @@ tf_cc_test(
)
cc_library(
- name = "literal_util",
- srcs = ["literal_util.cc"],
- hdrs = ["literal_util.h"],
+ name = "literal",
+ srcs = ["literal.cc"],
+ hdrs = ["literal.h"],
visibility = ["//visibility:public"],
deps = [
":array2d",
@@ -300,11 +301,12 @@ cc_library(
)
tf_cc_test(
- name = "literal_util_test",
- srcs = ["literal_util_test.cc"],
+ name = "literal_test",
+ srcs = ["literal_test.cc"],
deps = [
":array3d",
":array4d",
+ ":literal",
":literal_util",
":shape_util",
":test",
@@ -317,6 +319,26 @@ tf_cc_test(
)
cc_library(
+ name = "literal_util",
+ srcs = ["literal_util.cc"],
+ hdrs = ["literal_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":array2d",
+ ":array3d",
+ ":array4d",
+ ":literal",
+ ":shape_util",
+ ":sparse_index_array",
+ ":status_macros",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "error_spec",
hdrs = ["error_spec.h"],
)
@@ -327,6 +349,7 @@ cc_library(
hdrs = ["literal_comparison.h"],
deps = [
":error_spec",
+ ":literal",
":literal_util",
":util",
"//tensorflow/core:lib",
@@ -458,7 +481,7 @@ cc_library(
hdrs = ["packed_literal_reader.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":statusor",
@@ -489,7 +512,7 @@ cc_library(
hdrs = ["text_literal_reader.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":statusor",
@@ -505,7 +528,7 @@ tf_cc_test(
name = "text_literal_reader_test",
srcs = ["text_literal_reader_test.cc"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":test",
":text_literal_reader",
@@ -522,7 +545,7 @@ cc_library(
hdrs = ["text_literal_writer.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":types",
@@ -535,6 +558,7 @@ tf_cc_test(
name = "text_literal_writer_test",
srcs = ["text_literal_writer_test.cc"],
deps = [
+ ":literal",
":literal_util",
":test",
":test_helpers",
@@ -607,6 +631,7 @@ cc_library(
":array2d",
":array3d",
":array4d",
+ ":literal_util",
":util",
":window_util",
":xla_data_proto",
@@ -627,7 +652,7 @@ tf_cc_test(
":array2d",
":array3d",
":array4d",
- ":literal_util",
+ ":literal",
":reference_util",
":test",
":util",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 8f08d3b2e0..25666cad40 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -65,7 +65,7 @@ cc_library(
deps = [
":global_data",
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:service_interface",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 3d596a6e65..3a157c69cd 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 68f0d0ac78..69d4d300ca 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a6b9b47253..6933e9a838 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -82,6 +82,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
":math",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -123,7 +124,7 @@ cc_library(
hdrs = ["testing.h"],
deps = [
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc
index 1686389a23..031d62e4ff 100644
--- a/tensorflow/compiler/xla/client/lib/constants.cc
+++ b/tensorflow/compiler/xla/client/lib/constants.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace xla {
XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
- return ConstantLiteral(builder, Literal::Zero(type));
+ return ConstantLiteral(builder, LiteralUtil::Zero(type));
}
XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
@@ -38,7 +38,7 @@ XlaOp ZerosLike(XlaOp prototype) {
}
XlaOp One(XlaBuilder* builder, PrimitiveType type) {
- return ConstantLiteral(builder, Literal::One(type));
+ return ConstantLiteral(builder, LiteralUtil::One(type));
}
XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
@@ -61,7 +61,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
}
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
- return ConstantLiteral(builder, Literal::MinValue(type));
+ return ConstantLiteral(builder, LiteralUtil::MinValue(type));
}
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
@@ -81,7 +81,7 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
}
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
- return ConstantLiteral(builder, Literal::MaxValue(type));
+ return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
}
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
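The recurring `:literal_util` -> `:literal` dependency edits and `Literal::Foo` -> `LiteralUtil::Foo` call-site rewrites throughout this commit all follow one split: the `Literal` class itself now lives in literal.h, while the static factory helpers move onto a separate `LiteralUtil` class in literal_util.h. A minimal call-site sketch of the migration, using only spellings visible in these hunks (the function name is illustrative):

    // Before: factories were static members of Literal.
    //   return ConstantLiteral(builder, Literal::Zero(F32));
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/literal.h"       // class Literal
    #include "tensorflow/compiler/xla/literal_util.h"  // factory helpers

    xla::XlaOp MakeZero(xla::XlaBuilder* builder) {
      return xla::ConstantLiteral(builder, xla::LiteralUtil::Zero(xla::F32));
    }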
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index 1df4e6ea42..068cd2e586 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -31,7 +32,7 @@ class MathTest : public ClientLibraryTestBase {
XLA_TEST_F(MathTest, SqrtF32) {
XlaBuilder builder(TestName());
- Literal zero_literal = Literal::Zero(PrimitiveType::F32);
+ Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32);
std::unique_ptr<GlobalData> zero_data =
client_->TransferToServer(zero_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index cbe9e7fdd1..fd4e8fc390 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -68,4 +68,12 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
}
}
+XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
+ int64 n) {
+ auto a = Iota(builder, type, m);
+ auto b = Iota(builder, type, n);
+ auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
+ return ConvertElementType(indicator, type);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
index 2a409ae311..79707007b2 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.h
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -25,6 +25,10 @@ namespace xla {
// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...].
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
+// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere
+// else.
+XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
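IdentityMatrix needs no host data: it compares an iota over the row index against a broadcast iota over the column index, producing a boolean indicator that is true exactly where row == column, then converts that indicator to the requested element type. A minimal usage sketch (function name illustrative):

    #include "tensorflow/compiler/xla/client/lib/numeric.h"
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

    // Emits a 2 x 3 identity: [[1, 0, 0],
    //                          [0, 1, 0]]
    xla::XlaOp Eye(xla::XlaBuilder* builder) {
      return xla::IdentityMatrix(builder, xla::F32, /*m=*/2, /*n=*/3);
    }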
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 731ad13b8d..534c509868 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -49,7 +49,7 @@ int64 DataSizeOfShape(const Shape& shape) {
XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
if (ShapeUtil::IsArray(shape)) {
return Broadcast(
- ConstantLiteral(builder, Literal::One(shape.element_type())),
+ ConstantLiteral(builder, LiteralUtil::One(shape.element_type())),
AsInt64Slice(shape.dimensions()));
}
std::vector<XlaOp> parts;
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index ee00a9eada..763653c685 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -43,6 +43,7 @@ cc_library(
deps = [
":xla_computation",
"//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -64,7 +65,7 @@ tf_cc_test(
srcs = ["xla_builder_test.cc"],
deps = [
":xla_builder",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 12efcb4b4f..aac7df4383 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -736,7 +736,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
@@ -1117,6 +1117,35 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
});
}
+XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Infeed must have a layout");
+ }
+ const Shape infeed_instruction_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ *instr.mutable_shape() = infeed_instruction_shape;
+ instr.set_infeed_config(config);
+
+ if (ShapeUtil::IsArray(shape) && sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
+ // TODO(b/110793772): Support tiled array-shaped infeeds.
+ return InvalidArgument(
+ "Tiled sharding is not yet supported for array-shaped infeeds");
+ }
+
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ return InvalidArgument(
+ "Replicated sharding is not yet supported for infeeds");
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
+ });
+}
+
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1162,6 +1191,53 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
});
}
+XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+
+ // Check and set outfeed shape.
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "Outfeed shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ *instr.mutable_outfeed_shape() = shape_with_layout;
+
+ instr.set_outfeed_config(outfeed_config);
+
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
+ {operand, token});
+ });
+}
+
+XlaOp XlaBuilder::CreateToken() {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
+ });
+}
+
+XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (tokens.empty()) {
+ return InvalidArgument("AfterAll requires at least one operand");
+ }
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
+ });
+}
+
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape) {
@@ -1365,7 +1441,8 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
+XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@@ -1379,6 +1456,11 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferVariadicOpShape(
HloOpcode::kSort, operand_shape_ptrs));
+ if (dimension == -1) {
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ dimension = ShapeUtil::Rank(keys_shape) - 1;
+ }
+ instr.add_dimensions(dimension);
return values.has_value()
? AddInstruction(std::move(instr), HloOpcode::kSort,
{keys, *values})
@@ -1877,6 +1959,28 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
});
}
+XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
+
+ HloInstructionProto send_done_instr;
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ send_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
+ {send});
+ });
+}
+
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Recv HLO takes a single token operand. Generate the token to pass into
@@ -1917,6 +2021,27 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
});
}
+XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ recv_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
+
+ HloInstructionProto recv_done_instr;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ recv_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
+ {recv});
+ });
+}
+
StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
TF_RETURN_IF_ERROR(first_error_);
@@ -2565,8 +2690,9 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
-XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
- return keys.builder()->Sort(keys, std::move(values));
+XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
+ return keys.builder()->Sort(keys, std::move(values), dimension);
}
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
@@ -2624,6 +2750,34 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape,
return builder->Recv(shape, handle);
}
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return operand.builder()->SendWithToken(operand, token, handle);
+}
+
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return token.builder()->RecvWithToken(token, shape, handle);
+}
+
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return token.builder()->InfeedWithToken(token, shape, config);
+}
+
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
+ outfeed_config);
+}
+
+XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
+
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return builder->AfterAll(tokens);
+}
+
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index) {
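All the *WithToken variants added here share one contract: they consume a token operand and produce (or include in their result tuple) a fresh token, so ordinary data dependence fixes the execution order of side effects. A hedged sketch of chaining them, using only entry points declared in this patch plus GetTupleElement (builder name illustrative):

    xla::XlaBuilder b("token_demo");
    const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {4});

    xla::XlaOp token = xla::CreateToken(&b);
    // InfeedWithToken yields a tuple {data, token}.
    xla::XlaOp infeed = xla::InfeedWithToken(token, shape);
    xla::XlaOp data = xla::GetTupleElement(infeed, 0);
    xla::XlaOp token2 = xla::GetTupleElement(infeed, 1);
    // OutfeedWithToken consumes a token and yields a new one.
    xla::XlaOp token3 = xla::OutfeedWithToken(data, token2, shape,
                                              /*outfeed_config=*/"");
    // AfterAll joins several tokens into a single ordering point.
    xla::XlaOp joined = xla::AfterAll(&b, {token2, token3});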
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 274aba8a31..2be6f4a553 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -532,6 +533,8 @@ class XlaBuilder {
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(const Shape& shape, const string& config = "");
+ XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
@@ -541,6 +544,9 @@ class XlaBuilder {
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
+ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
@@ -788,17 +794,23 @@ class XlaBuilder {
// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
- // * The keys must be a rank-1 tensor (i.e. an array).
- // * The result is a sorted array of keys.
+ // * If the keys are a rank-1 tensor (an array), the result is a sorted array
+ // of keys, in ascending order.
+ // * If the keys have higher rank, the keys are sorted along the provided
+ // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+ // value of 0 will independently sort every column, and a dimension value of 1
+ // will independently sort each row. If no dimension number is provided, then
+ // the last dimension is chosen by default.
//
// If both keys and values are provided:
- // * The keys and the values must be rank-1 tensors with the same dimensions.
- // The element types of the tensors may be different.
- // * The result is a tuple that consists of a sorted array of keys as the
- // first element, and an array with their corresponding values as the second
- // element.
- XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values =
- tensorflow::gtl::nullopt);
+ // * The keys and the values must be tensors with the same dimensions. The
+ // element types of the tensors may be different.
+ // * The result is a tuple that consists of a sorted tensor of keys (along the
+ // provided dimension, as above) as the first element, and a tensor with their
+ // corresponding values as the second element.
+ XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@@ -839,11 +851,23 @@ class XlaBuilder {
// Enqueues a Send node onto the computation, to send the given operand to
// a Recv instruction that shares the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
+ XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp CreateToken();
+
+ // Enqueues an AfterAll operation which produces a token-shaped value from a
+ // variadic number of token-shaped operands. Used for joining tokens.
+ XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
+ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
// Normalizes operand across spatial and batch dimensions for each feature.
//
@@ -1229,7 +1253,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> permutation);
friend XlaOp Rev(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
- friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values);
+ friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension);
friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
friend XlaOp Map(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<XlaOp> operands,
@@ -1264,6 +1289,18 @@ class XlaBuilder {
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
int64 feature_index);
+ friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+ friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config);
+ friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp CreateToken(XlaBuilder* builder);
+ friend XlaOp AfterAll(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> tokens);
};
// RAII-style object: sets the current sharding assignment in builder on
@@ -1595,6 +1632,13 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type,
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config = "");
+// Variant of Infeed which takes a token-shaped operand and produces a
+// two-element tuple containing the data value and a token-shaped value.
+// Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
+
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
@@ -1604,6 +1648,13 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
+// Variant of Outfeed which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+
// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands);
@@ -1844,16 +1895,25 @@ XlaOp Transpose(const XlaOp& operand,
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
-// * The result is a sorted array of keys.
+// Enqueues a sort (as increasing order) instruction onto the computation.
+// If only keys are provided:
+// * If the keys are a rank-1 tensor (an array), the result is a sorted array
+// of keys, in ascending order.
+// * If the keys have higher rank, the keys are sorted along the provided
+// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+// value of 0 will independently sort every column, and a dimension value of 1
+// will independently sort each row. If no dimension number is provided, then
+// the last dimension is chosen by default.
//
// If both keys and values are provided:
-// * The keys and the values must be rank-1 tensors with the same dimensions.
-// The element types of the tensors may be different.
-// * The result is a tuple that consists of a sorted array of keys as the
-// first element, and an array with their corresponding values as the second
-// element.
+// * The keys and the values must be tensors with the same dimensions. The
+// element types of the tensors may be different.
+// * The result is a tuple that consists of a sorted tensor of keys (along the
+// provided dimension, as above) as the first element, and a tensor with their
+// corresponding values as the second element.
XlaOp Sort(XlaOp keys,
- tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt);
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
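A worked example of the new `dimension` parameter under the semantics documented above: dimension 0 sorts every column independently, and the default of -1 selects the last dimension (here, the rows):

    xla::XlaBuilder b("sort_demo");
    // keys = [[3, 1, 4],
    //         [2, 5, 0]]
    xla::XlaOp keys = xla::ConstantR2<xla::int32>(&b, {{3, 1, 4}, {2, 5, 0}});

    // Each column sorted independently: [[2, 1, 0],
    //                                    [3, 5, 4]]
    xla::XlaOp by_column =
        xla::Sort(keys, tensorflow::gtl::nullopt, /*dimension=*/0);

    // Default dimension (-1) = last dimension, so each row is sorted:
    //   [[1, 3, 4],
    //    [0, 2, 5]]
    xla::XlaOp by_row = xla::Sort(keys);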
@@ -1895,12 +1955,38 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
// a Recv instruction that shares the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
+// Variant of Send which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
+// Variant of Recv which takes a token-shaped operand and produces a two-element
+// tuple containing the data value and a token-shaped value. Tokens are used
+// for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues an operation (AfterAll) with no operands that produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// This is a separate method from AfterAll to facilitate the removal of
+// operand-less AfterAll instructions.
+// TODO(b/110532604): Remove this function when all tokens are derived from a
+// single token generated or passed into the entry computation.
+XlaOp CreateToken(XlaBuilder* builder);
+
+// Enqueues an AfterAll instruction which produces a token-shaped value and
+// takes a variadic number of token-shaped operands. The number of operands must
+// be greater than zero. Used for joining tokens.
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
@@ -1943,12 +2029,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
+ return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -1960,44 +2046,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*Literal::CreateR1(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2020,13 +2106,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *Literal::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(builder, *Literal::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2039,13 +2125,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *Literal::CreateR1(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *Literal::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2053,12 +2139,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
- return ConstantLiteral(builder, *Literal::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2066,14 +2154,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *Literal::CreateR2FromArray2D<NativeT>(values));
+ *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2082,7 +2171,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
new file mode 100644
index 0000000000..5db124b5a2
--- /dev/null
+++ b/tensorflow/compiler/xla/literal.cc
@@ -0,0 +1,1967 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal.h"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::strings::Printf;
+using tensorflow::strings::StrCat;
+
+namespace xla {
+
+namespace {
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian.
+//
+// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
+void ConvertEndianShort(string* bytes) {
+ CHECK_EQ(bytes->size() % 2, 0);
+ for (int64 i = 0; i < bytes->size(); i += 2) {
+ std::swap((*bytes)[i], (*bytes)[i + 1]);
+ }
+}
+
+void ConvertEndianShort(char* bytes, int64 size) {
+ CHECK_EQ(size % 2, 0);
+ for (int64 i = 0; i < size; i += 2) {
+ std::swap(bytes[i], bytes[i + 1]);
+ }
+}
+
+} // namespace
+
+LiteralBase::~LiteralBase() {}
+
+std::ostream& operator<<(std::ostream& out, const Literal& literal) {
+ out << literal.ToString();
+ return out;
+}
+
+Literal::StrideConfig::StrideConfig(
+ const Shape& source_shape, const Shape& dest_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : dimensions(dimensions),
+ base(dimensions.size(), 0),
+ step(dimensions.size(), 1) {
+ if (!dimensions.empty()) {
+ // Selects the shape with the largest minor dimension as the one upon
+ // which to run the tight stride loop.
+ if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
+ dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
+ minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
+ dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
+ } else {
+ minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
+ source_stride =
+ IndexUtil::GetDimensionStride(source_shape, minor_dimension);
+ }
+ minor_loop_size = dimensions[minor_dimension];
+ step[minor_dimension] = minor_loop_size;
+ }
+}
+
+Literal::Literal(const Shape& shape)
+ : Literal(shape, /*allocate_arrays=*/true) {}
+
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ SetPiece(subshape, &child_piece, allocate_arrays);
+
+ piece->emplace_back(std::move(child_piece));
+ }
+ } else if (ShapeUtil::IsArray(shape)) {
+ if (allocate_arrays) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(shape.layout());
+ piece->set_buffer(
+ new char[max_sparse_elements *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+ piece->set_sparse_indices(
+ new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
+ } else {
+ piece->set_buffer(new char[piece->size_bytes()]);
+ }
+ }
+ } else {
+ // If the shape is neither an array nor tuple, then it must be
+ // zero-sized. Otherwise, some memory needs to be allocated for it.
+ CHECK_EQ(piece->size_bytes(), 0);
+ }
+}
+
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+ CHECK(&root_piece_->subshape() == shape_.get());
+
+ SetPiece(*shape_, root_piece_, allocate_arrays);
+}
+
+Literal::~Literal() {
+ if (root_piece_ != nullptr) {
+ DeallocateBuffers();
+ delete root_piece_;
+ }
+}
+
+void Literal::DeallocateBuffers() {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete[] piece->buffer();
+ delete piece->sparse_indices();
+ }
+ });
+}
+
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
+Literal& Literal::operator=(Literal&& other) {
+ DCHECK(&other.root_piece_->subshape() == other.shape_.get());
+ using std::swap;
+ swap(shape_, other.shape_);
+ swap(root_piece_, other.root_piece_);
+ DCHECK(&root_piece_->subshape() == shape_.get());
+
+ return *this;
+}
+
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
+ auto literal = MakeUnique<Literal>(shape);
+ literal->root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (ShapeUtil::IsArray(piece->subshape())) {
+ memset(piece->untyped_data(), 0, piece->size_bytes());
+ }
+ });
+ return literal;
+}
+
+const SparseIndexArray* LiteralBase::sparse_indices(
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).sparse_indices();
+}
+
+SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
+ return piece(shape_index).sparse_indices();
+}
+
+template <typename NativeT>
+Status Literal::CopySliceFromInternal(
+ const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
+ TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
+
+ auto linear_index = [](const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
+ };
+
+ if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
+ ShapeUtil::Rank(shape()) == 0) {
+ // If any of the two shapes are scalars, we can just call the StridedCopy()
+ // directly, and we know we will be copying only one value.
+ TF_RET_CHECK(copy_size.empty());
+ StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
+ src_literal.data<NativeT>(),
+ linear_index(src_literal.shape(), src_base), 0, 1);
+ } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
+ !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
+ // Perform copy if neither src nor dest has dimensions with zero element,
+ // otherwise it's a no-op.
+ TF_RET_CHECK(src_base.size() == dest_base.size());
+ TF_RET_CHECK(src_base.size() == copy_size.size());
+
+ // Scan the source from minor, stepping in copy size blocks, then within
+ // the index enumeration functor, do a strided copy advancing source index
+ // by one (walking through the minor dimension), and destination index by
+ // proper stride size at the matching dimension.
+ DimensionVector src_indexes(src_base.size(), 0);
+ DimensionVector dest_indexes(dest_base.size(), 0);
+ Literal::StrideConfig stride_config(src_literal.shape(), shape(),
+ copy_size);
+
+ auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ // Map from multi-dimensional index, to source index.
+ std::transform(indexes.begin(), indexes.end(), src_base.begin(),
+ src_indexes.begin(), std::plus<int64>());
+ // Map from multi-dimensional index, to destination index.
+ std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
+ dest_indexes.begin(), std::plus<int64>());
+
+ int64 src_index = linear_index(src_literal.shape(), src_indexes);
+ int64 dest_index = linear_index(shape(), dest_indexes);
+
+ // `this->` is needed to work around MSVC bug #16882
+ StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
+ src_literal.data<NativeT>(), src_index,
+ stride_config.source_stride, stride_config.minor_loop_size);
+ return true;
+ };
+
+ ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
+ stride_config.dimensions, stride_config.step,
+ copy_proc);
+ }
+ return Status::OK();
+}
+
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index) {
+ DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
+ const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ src_literal.shape(), src_index);
+ const int64 dest_linear_index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ char* dest_address =
+ static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
+ const char* source_address =
+ static_cast<const char*>(src_literal.untyped_data()) +
+ src_linear_index * primitive_size;
+ if (dest_address != source_address) {
+ memcpy(dest_address, source_address, primitive_size);
+ }
+ return Status::OK();
+}
+
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
+ const LiteralProto& proto) {
+ if (!proto.has_shape()) {
+ return InvalidArgument("LiteralProto has no shape");
+ }
+ if (!LayoutUtil::HasLayout(proto.shape())) {
+ return InvalidArgument("LiteralProto has no layout");
+ }
+
+ auto literal = MakeUnique<Literal>(proto.shape());
+
+ TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
+
+ if (ShapeUtil::IsTuple(piece->subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece->subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece->subshape()),
+ proto_element->tuple_literals_size());
+ }
+ return Status::OK();
+ }
+ if (piece->subshape().element_type() == TOKEN) {
+ return Status::OK();
+ }
+
+ CHECK(ShapeUtil::IsArray(piece->subshape()));
+ TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+ return Status::OK();
+ }));
+
+ return std::move(literal);
+}
+
+std::vector<Literal> Literal::DecomposeTuple() {
+ CHECK(ShapeUtil::IsTuple(shape()));
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
+ /*allocate_arrays=*/false));
+ Literal& element = elements.back();
+ element.root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* dest_piece) {
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer and sparse indices over to the element
+ // Literal.
+ dest_piece->set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ dest_piece->set_sparse_indices(src_piece.sparse_indices());
+ src_piece.set_sparse_indices(nullptr);
+ });
+ }
+ // Set this literal to be nil-shaped.
+ *this = Literal();
+ return elements;
+}
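DecomposeTuple is a move, not a copy: each returned element steals its buffer (and sparse index array) from the tuple, and the source literal is reset to the nil shape. A hedged sketch, assuming the LiteralUtil factory helpers that accompany this split (including MakeTuple):

    auto tuple = xla::LiteralUtil::MakeTuple(
        {xla::LiteralUtil::CreateR0<float>(1.0f).get(),
         xla::LiteralUtil::CreateR1<xla::int64>({1, 2, 3}).get()});
    // Buffers move out of `*tuple`; no element data is copied.
    std::vector<xla::Literal> parts = tuple->DecomposeTuple();
    // `*tuple` is now nil-shaped; parts[0] and parts[1] own the buffers.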
+
+namespace {
+
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
+// the array slices are indicated by dest_shape and src_shape respectively.
+template <typename NativeT>
+void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ tensorflow::gtl::ArraySlice<NativeT> src,
+ const Shape& dest_shape, const Shape& src_shape) {
+ CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
+ if (ShapeUtil::IsZeroElementArray(dest_shape)) {
+ return;
+ }
+ std::vector<int64> index(ShapeUtil::Rank(dest_shape));
+ do {
+ dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
+ src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
+ } while (IndexUtil::BumpIndices(dest_shape, &index));
+}
+
+} // namespace
+
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
+ CHECK(subshape_ != nullptr);
+ CHECK(src.subshape_ != nullptr);
+ if (ShapeUtil::Equal(subshape(), src.subshape())) {
+ // If the layouts are equal it's faster just to memcpy.
+ memcpy(buffer(), src.buffer(), src.size_bytes());
+ } else {
+ TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
+ std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
+ switch (subshape().element_type()) {
+#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
+ case (XLA_T): \
+ CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
+ subshape(), src.subshape()); \
+ break;
+ COPY_ELEMENTS(U8, uint8);
+ COPY_ELEMENTS(U16, uint16);
+ COPY_ELEMENTS(U32, uint32);
+ COPY_ELEMENTS(U64, uint64);
+ COPY_ELEMENTS(S8, int8);
+ COPY_ELEMENTS(S16, int16);
+ COPY_ELEMENTS(S32, int32);
+ COPY_ELEMENTS(S64, int64);
+ COPY_ELEMENTS(F16, half);
+ COPY_ELEMENTS(BF16, bfloat16);
+ COPY_ELEMENTS(F32, float);
+ COPY_ELEMENTS(F64, double);
+ COPY_ELEMENTS(C64, complex64);
+ COPY_ELEMENTS(PRED, bool);
+#undef COPY_ELEMENTS
+ default:
+ return Unimplemented(
+ "Copying a Literal object with element type %s is not implemented.",
+ PrimitiveType_Name(subshape().element_type()).c_str());
+ }
+ }
+ return Status::OK();
+}
+
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index,
+ const ShapeIndex& src_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ const Shape& src_subshape =
+ ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
+ if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
+ return InvalidArgument(
+ "Destination subshape incompatible with source subshape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_subshape).c_str());
+ }
+ return root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (!ShapeUtil::IsArray(piece->subshape())) {
+ return Status::OK();
+ }
+
+ // Determine if this index is in the part of this literal that we want
+ // to copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ return Status::OK();
+ }
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+ TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+ return Status::OK();
+ });
+}
+
+Status Literal::MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
+ return InvalidArgument(
+ "Destination subshape not equal to source shape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_literal.shape()).c_str());
+ }
+
+ src_literal.root_piece_->ForEachSubpiece(
+ [&](const ShapeIndex& src_index, const Piece& src_piece) {
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ return;
+ }
+
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ delete dest_piece.sparse_indices();
+ dest_piece.set_sparse_indices(src_piece.sparse_indices());
+ });
+
+ src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ delete src_literal.root_piece_;
+ src_literal.root_piece_ = new LiteralBase::Piece();
+ src_literal.root_piece_->set_subshape(src_literal.shape_.get());
+
+ return Status::OK();
+}
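MoveFrom goes the other way: it donates every array buffer of the source literal into the subtree at dest_shape_index, again leaving the source nil-shaped, so tuples can be assembled without copying. A minimal sketch:

    // dst is a tuple whose element {0} exactly matches src's shape.
    xla::Literal dst(xla::ShapeUtil::MakeTupleShape(
        {xla::ShapeUtil::MakeShape(xla::F32, {2})}));
    xla::Literal src(xla::ShapeUtil::MakeShape(xla::F32, {2}));
    TF_CHECK_OK(dst.MoveFrom(std::move(src), /*dest_shape_index=*/{0}));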
+
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
+ TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
+ << ShapeUtil::HumanString(src_literal.shape());
+ TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
+
+ switch (shape().element_type()) {
+ case U8:
+ return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
+ copy_size);
+ case U16:
+ return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
+ copy_size);
+ case U32:
+ return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
+ copy_size);
+ case U64:
+ return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
+ copy_size);
+ case S8:
+ return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
+ copy_size);
+ case S16:
+ return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
+ copy_size);
+ case S32:
+ return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
+ copy_size);
+ case S64:
+ return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
+ copy_size);
+ case F16:
+ return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
+ copy_size);
+ case BF16:
+ return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
+ copy_size);
+ case F32:
+ return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
+ copy_size);
+ case F64:
+ return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
+ copy_size);
+ case C64:
+ return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
+ copy_size);
+ case PRED:
+ return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
+ copy_size);
+ default:
+ break;
+ }
+ return Unimplemented(
+ "Copying a slice from a Literal object with element type %d is not "
+ "implemented.",
+ shape().element_type());
+}
+
+void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(element_count(), values.bits());
+ CHECK_EQ(shape().element_type(), PRED);
+ for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
+ Set({i}, values.get(i));
+ }
+}
+
+std::unique_ptr<Literal> LiteralBase::Relayout(
+ const Layout& new_layout, const ShapeIndex& shape_index) const {
+ // Create new shape with 'new_layout' set at the given shape index.
+ Shape new_shape = shape();
+ Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
+ TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
+ *subshape->mutable_layout() = new_layout;
+ auto result = MakeUnique<Literal>(new_shape);
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
+}
+
+std::unique_ptr<Literal> LiteralBase::Relayout(
+ const Shape& shape_with_layout) const {
+ CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
+ << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
+ << " not compatible with literal shape "
+ << ShapeUtil::HumanString(shape());
+ std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ ShapeUtil::ForEachSubshape(
+ result->shape(),
+ [this, &result](const Shape& subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ TF_CHECK_OK(result->CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
+ }
+ });
+ return result;
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return InvalidArgument("Broadcast only supports arrays.");
+ }
+
+ for (int64 i = 0; i < dimensions.size(); i++) {
+ TF_RET_CHECK(shape().dimensions(i) ==
+ result_shape.dimensions(dimensions[i]));
+ }
+
+ std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+
+ // scratch_source_index is temporary storage space for the computed index into
+ // the input literal. We put it here to avoid allocating an std::vector in
+ // every iteration of ShapeUtil::ForEachIndex.
+ std::vector<int64> scratch_source_index(shape().dimensions_size());
+
+ char* dest_data = static_cast<char*>(result->untyped_data());
+ const char* source_data = static_cast<const char*>(untyped_data());
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ ShapeUtil::ForEachIndex(
+ result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ for (int64 i = 0; i < dimensions.size(); ++i) {
+ scratch_source_index[i] = output_index[dimensions[i]];
+ }
+ int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ result_shape, output_index);
+ int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ shape(), scratch_source_index);
+ memcpy(dest_data + primitive_size * dest_index,
+ source_data + primitive_size * source_index, primitive_size);
+ return true;
+ });
+
+ return std::move(result);
+}
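Here `dimensions[i]` names the result dimension that source dimension i is aligned with; all other result dimensions replicate the data. A hedged sketch using the companion LiteralUtil factories:

    auto src = xla::LiteralUtil::CreateR1<xla::int32>({1, 2});  // s32[2]
    auto result = src->Broadcast(xla::ShapeUtil::MakeShape(xla::S32, {2, 3}),
                                 /*dimensions=*/{0})
                      .ValueOrDie();
    // result: [[1, 1, 1],
    //          [2, 2, 2]]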
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return InvalidArgument("Reshape does not support tuples.");
+ }
+ std::unique_ptr<Literal> output;
+ if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
+ output =
+ Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
+ } else {
+ output = CloneToUnique();
+ }
+ // Because the layout is monotonic, we can simply reuse the same sequence of
+ // values without changing their order.
+ *output->mutable_shape_do_not_use() =
+ ShapeUtil::MakeShape(shape().element_type(), dimensions);
+
+ int64 elements_before = ShapeUtil::ElementsIn(shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ if (elements_before != elements_after) {
+ return InvalidArgument(
+ "Shapes before and after Literal::Reshape have different numbers "
+ "of elements: %s vs %s.",
+ ShapeUtil::HumanString(shape()).c_str(),
+ ShapeUtil::HumanString(output->shape()).c_str());
+ }
+ return std::move(output);
+}
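Reshape relayouts to the dim0-major default only when needed, then reinterprets the same element sequence under the new dimensions; the element count must match. A minimal sketch:

    auto v = xla::LiteralUtil::CreateR1<xla::int32>({1, 2, 3, 4, 5, 6});
    auto m = v->Reshape({2, 3}).ValueOrDie();
    // m: [[1, 2, 3],
    //     [4, 5, 6]]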
+
+std::unique_ptr<Literal> LiteralBase::Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const {
+ CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
+ CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
+ << "Given permutation is not a permutation of dimension numbers";
+ // To transpose the array, we just permute the dimensions and layout, and
+ // do a straight memory copy of the raw data set.
+ // This is considerably faster than iterating over every array element using
+ // the EachCell<>() and Set<>() APIs.
+ std::vector<int64> inverse_permutation = InversePermutation(permutation);
+ Shape permuted_shape =
+ ShapeUtil::PermuteDimensions(inverse_permutation, shape());
+ // Replace the layout with one affine to this shape, such that a
+ // transpose operation can be performed by leaving the flat values
+ // representation intact.
+ // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
+ // The shape with affine layout resulting from that operation will be
+ // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the one
+ // of size 8) as the most-minor dimension.
+ //
+ // Essentially, given MinMaj(Di) the position of the Di dimension within the
+ // minor to major vector, and given T(Di) the index that the original Di
+ // dimension has within the transposed array, a layout is affine if
+ // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
+ // vector of the affine layout.
+ CHECK(LayoutUtil::IsDenseArray(permuted_shape));
+ Layout* layout = permuted_shape.mutable_layout();
+ layout->clear_minor_to_major();
+ for (auto index : LayoutUtil::MinorToMajor(shape())) {
+ layout->add_minor_to_major(inverse_permutation[index]);
+ }
+ auto new_literal = MakeUnique<Literal>(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ ShapeUtil::ByteSizeOf(shape()));
+ std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ return new_literal;
+}
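Since only minor_to_major is permuted and the flat bytes are copied verbatim, this Transpose costs one memcpy regardless of rank. A small sketch of the observable result:

    auto m = xla::LiteralUtil::CreateR2<float>({{1, 2, 3},
                                                {4, 5, 6}});  // f32[2,3]{1,0}
    auto mt = m->Transpose(/*permutation=*/{1, 0});           // f32[3,2]{0,1}
    // Same flat data, transposed view:
    //   mt->Get<float>({0, 1}) == 4, mt->Get<float>({2, 0}) == 3.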
+
+template <typename NativeT>
+std::unique_ptr<Literal> LiteralBase::SliceInternal(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices) const {
+ auto result_literal = MakeUnique<Literal>(result_shape);
+ DimensionVector new_indices(ShapeUtil::Rank(result_shape));
+ result_literal->EachCell<NativeT>(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ NativeT value = Get<NativeT>(new_indices);
+ result_literal->Set<NativeT>(indices, value);
+ });
+ return result_literal;
+}
+
+std::unique_ptr<Literal> LiteralBase::Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const {
+ CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
+
+ DimensionVector result_dimensions;
+ for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
+ CHECK_GE(start_indices[dnum], 0);
+ CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
+ << "dnum = " << dnum;
+ int64 dimension = limit_indices[dnum] - start_indices[dnum];
+ CHECK_GE(dimension, 0) << "dnum = " << dnum;
+ result_dimensions.push_back(dimension);
+ }
+ const auto result_shape =
+ ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
+ LayoutUtil::MinorToMajor(shape()));
+ switch (result_shape.element_type()) {
+ case F32:
+ return SliceInternal<float>(result_shape, start_indices);
+ case BF16:
+ return SliceInternal<bfloat16>(result_shape, start_indices);
+ case C64:
+ return SliceInternal<complex64>(result_shape, start_indices);
+ case S32:
+ return SliceInternal<int32>(result_shape, start_indices);
+ case U32:
+ return SliceInternal<uint32>(result_shape, start_indices);
+ default:
+ LOG(FATAL) << "not yet implemented: "
+ << PrimitiveType_Name(result_shape.element_type());
+ }
+}
+
+Literal LiteralBase::Clone() const {
+ Literal result(shape());
+ TF_CHECK_OK(result.CopyFrom(*this));
+ return result;
+}
+
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
+ auto result = MakeUnique<Literal>(shape());
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
+}
+
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsDenseArray(subshape));
+ switch (subshape.element_type()) {
+ case PRED:
+ return Get<bool>(multi_index, shape_index) ? "true" : "false";
+ case S8:
+ return StrCat(Get<int8>(multi_index, shape_index));
+ case S16:
+ return StrCat(Get<int16>(multi_index, shape_index));
+ case S32:
+ return StrCat(Get<int32>(multi_index, shape_index));
+ case S64:
+ return StrCat(Get<int64>(multi_index, shape_index));
+ case U8:
+ return StrCat(Get<uint8>(multi_index, shape_index));
+ case U16:
+ return StrCat(Get<uint16>(multi_index, shape_index));
+ case U32:
+ return StrCat(Get<uint32>(multi_index, shape_index));
+ case U64:
+ return StrCat(Get<uint64>(multi_index, shape_index));
+ case F16:
+ return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
+ case F32:
+ return StrCat(Get<float>(multi_index, shape_index));
+ case BF16:
+ return StrCat(
+ static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
+ case F64:
+ return StrCat(Get<double>(multi_index, shape_index));
+ case C64: {
+ complex64 c = Get<complex64>(multi_index, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
+ }
+ default:
+ LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
+ }
+}
+
+string LiteralBase::GetSparseElementAsString(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ switch (subshape.element_type()) {
+ case PRED:
+ return GetSparseElement<bool>(sparse_element_number, shape_index)
+ ? "true"
+ : "false";
+ case S8:
+ return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
+ case S16:
+ return StrCat(
+ GetSparseElement<int16>(sparse_element_number, shape_index));
+ case S32:
+ return StrCat(
+ GetSparseElement<int32>(sparse_element_number, shape_index));
+ case S64:
+ return StrCat(
+ GetSparseElement<int64>(sparse_element_number, shape_index));
+ case U8:
+ return StrCat(
+ GetSparseElement<uint8>(sparse_element_number, shape_index));
+ case U16:
+ return StrCat(
+ GetSparseElement<uint16>(sparse_element_number, shape_index));
+ case U32:
+ return StrCat(
+ GetSparseElement<uint32>(sparse_element_number, shape_index));
+ case U64:
+ return StrCat(
+ GetSparseElement<uint64>(sparse_element_number, shape_index));
+ case F16:
+ return StrCat(static_cast<float>(
+ GetSparseElement<half>(sparse_element_number, shape_index)));
+ case F32:
+ return StrCat(
+ GetSparseElement<float>(sparse_element_number, shape_index));
+ case BF16:
+ return StrCat(static_cast<float>(
+ GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
+ case F64:
+ return StrCat(
+ GetSparseElement<double>(sparse_element_number, shape_index));
+ case C64: {
+ complex64 c =
+ GetSparseElement<complex64>(sparse_element_number, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
+ }
+ default:
+ LOG(FATAL) << "Invalid element type for sparse arrays: "
+ << PrimitiveType_Name(subshape.element_type());
+ }
+}
+
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(shape()));
+ switch (shape().element_type()) {
+ case PRED:
+ return Get<bool>(multi_index);
+ case U8:
+ return Get<uint8>(multi_index);
+ case S32:
+ return Get<int32>(multi_index);
+ case S64:
+ return Get<int64>(multi_index);
+ case U32:
+ return Get<uint32>(multi_index);
+ case U64:
+ return Get<uint64>(multi_index);
+ default:
+ return FailedPrecondition(
+ "Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()).c_str());
+ }
+}
+
+size_t LiteralBase::Hash() const {
+ using tensorflow::Hash64;
+ using tensorflow::Hash64Combine;
+
+ size_t hash_value = ShapeUtil::Hash(shape());
+
+ ShapeUtil::ForEachSubshape(
+ shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+ if (!ShapeUtil::IsArray(subshape)) {
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDense(subshape.layout()));
+ hash_value = Hash64Combine(
+ hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
+ size_bytes(index)));
+ });
+
+ return hash_value;
+}
+
+Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value) {
+ CHECK(LayoutUtil::IsDenseArray(shape()));
+ switch (shape().element_type()) {
+ case PRED:
+ Set<bool>(multi_index, value);
+ break;
+ case U8:
+ Set<uint8>(multi_index, value);
+ break;
+ case S32:
+ Set<int32>(multi_index, value);
+ break;
+ case S64:
+ Set<int64>(multi_index, value);
+ break;
+ case U32:
+ Set<uint32>(multi_index, value);
+ break;
+ case U64:
+ Set<uint64>(multi_index, value);
+ break;
+ default:
+ return FailedPrecondition(
+ "Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()).c_str());
+ }
+ return Status::OK();
+}
+
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
+ const Piece& p = piece(shape_index);
+ CHECK_GE(sparse_element_number, 0);
+ CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
+ return p.sparse_indices()->At(sparse_element_number);
+}
+
+void Literal::SortSparseElements(const ShapeIndex& shape_index) {
+ piece(shape_index).SortSparseElements();
+}
+
+void LiteralBase::Piece::SortSparseElements() {
+ switch (subshape().element_type()) {
+ case PRED:
+ SortSparseElementsInternal<bool>();
+ break;
+ case S8:
+ SortSparseElementsInternal<int8>();
+ break;
+ case U8:
+ SortSparseElementsInternal<uint8>();
+ break;
+ case S16:
+ SortSparseElementsInternal<int16>();
+ break;
+ case U16:
+ SortSparseElementsInternal<uint16>();
+ break;
+ case S32:
+ SortSparseElementsInternal<int32>();
+ break;
+ case U32:
+ SortSparseElementsInternal<uint32>();
+ break;
+ case S64:
+ SortSparseElementsInternal<int64>();
+ break;
+ case U64:
+ SortSparseElementsInternal<uint64>();
+ break;
+ case F32:
+ SortSparseElementsInternal<float>();
+ break;
+ case F64:
+ SortSparseElementsInternal<double>();
+ break;
+ case C64:
+ SortSparseElementsInternal<complex64>();
+ break;
+ case F16:
+ SortSparseElementsInternal<half>();
+ break;
+ case BF16:
+ SortSparseElementsInternal<bfloat16>();
+ break;
+ default:
+ LOG(FATAL) << "Element type not valid for sparse array: "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+template <typename NativeT>
+void LiteralBase::Piece::SortSparseElementsInternal() {
+ CHECK(LayoutUtil::IsSparseArray(subshape()));
+ int64 num_elements = sparse_indices()->index_count();
+ auto values = data<NativeT>();
+ CHECK_LE(num_elements, values.size());
+ sparse_indices()->SortWithValues(
+ tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
+}
+
+namespace {
+
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
+ bool print_layout, std::vector<string>* pieces) {
+ const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+ CHECK(LayoutUtil::HasLayout(literal.shape()));
+ CHECK(LayoutUtil::HasLayout(subshape));
+
+ auto shape_to_string = [print_layout](const Shape& shape) {
+ if (print_layout) {
+ return ShapeUtil::HumanStringWithLayout(shape);
+ } else {
+ return ShapeUtil::HumanString(shape);
+ }
+ };
+
+ // TODO(b/32894291): refactor this code to reduce code duplication.
+ if (ShapeUtil::IsTuple(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" (\n");
+ std::vector<string> tuple_pieces;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
+ ShapeIndex element_index = shape_index;
+ element_index.push_back(i);
+ std::vector<string> element_pieces;
+ ToStringHelper(literal, element_index, print_layout, &element_pieces);
+ tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+ }
+ pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+ pieces->push_back("\n)");
+ return;
+ }
+
+ if (ShapeUtil::IsToken(subshape)) {
+ pieces->push_back("token");
+ return;
+ }
+
+ if (LayoutUtil::IsSparseArray(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back("{");
+ int64 rank = ShapeUtil::Rank(subshape);
+ int64 num_elements = literal.sparse_element_count();
+ for (int64 i = 0; i < num_elements; ++i) {
+ if (i > 0) {
+ pieces->push_back(", ");
+ }
+ if (rank == 1) {
+ pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
+ pieces->push_back(": ");
+ } else {
+ pieces->push_back("[");
+ pieces->push_back(
+ tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
+ pieces->push_back("]: ");
+ }
+ pieces->push_back(literal.GetSparseElementAsString(i));
+ }
+ pieces->push_back("}");
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDenseArray(subshape));
+
+ auto element_to_string =
+ [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ PrimitiveType element_type = subshape.element_type();
+ if (element_type == PRED) {
+ // We display predicates in a densely packed form.
+ return literal.Get<bool>(indices, shape_index) ? "1" : "0";
+ }
+ return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
+ literal.GetAsString(indices, shape_index);
+ };
+
+ if (ShapeUtil::Rank(subshape) == 0) {
+ pieces->push_back(literal.GetAsString({}, shape_index));
+ } else if (ShapeUtil::Rank(subshape) == 1) {
+ pieces->push_back("{");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(element_to_string({i0}));
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 2) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(" { ");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(element_to_string({i0, i1}));
+ }
+ pieces->push_back(" ");
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 3) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(i0 > 0 ? ",\n{" : "{");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(i1 > 0 ? ",\n { " : " { ");
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(element_to_string({i0, i1, i2}));
+ }
+ pieces->push_back(" }");
+ }
+ pieces->push_back(" }");
+ }
+ pieces->push_back("\n}");
+ } else if (ShapeUtil::Rank(subshape) == 4) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(" {");
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3}));
+ }
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
+ }
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 5) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(" {");
+ for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
+ }
+ pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
+ : "},\n");
+ }
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
+ }
+ pieces->push_back("}");
+ } else {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {");
+ literal.EachCellAsString(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ pieces->push_back(" ");
+ pieces->push_back(value);
+ });
+ pieces->push_back("}");
+ }
+}
+
+} // namespace
+
+int64 LiteralBase::sparse_element_count() const {
+ CHECK(LayoutUtil::IsSparseArray(shape()));
+ return sparse_indices()->index_count();
+}
+
+string LiteralBase::ToString(bool print_layout) const {
+ std::vector<string> pieces;
+ CHECK(LayoutUtil::HasLayout(this->shape()));
+ ToStringHelper(*this, {}, print_layout, &pieces);
+ return tensorflow::str_util::Join(pieces, "");
+}
+
+void LiteralBase::EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
+ return;
+ }
+ std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
+ shape(), /*linear_index=*/0);
+ do {
+ per_cell(indices, GetAsString(indices));
+ } while (IndexUtil::BumpIndices(shape(), &indices));
+}
+
+namespace {
+template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
+std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
+ const LiteralBase& src_literal, const ConverterType& converter) {
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+ src_literal.shape(),
+ primitive_util::NativeToPrimitiveType<NativeDestT>()));
+ auto src_data = src_literal.data<NativeSrcT>();
+ auto dest_data = result_literal->template data<NativeDestT>();
+ int64 num_elements = src_literal.element_count();
+
+ for (int64 i = 0; i < num_elements; ++i) {
+ dest_data[i] = converter(src_data[i]);
+ }
+ return result_literal;
+}
+
+template <typename NativeSrcT, typename NativeDestT>
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+ const LiteralBase& src_literal) {
+ auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
+ auto converter = [](NativeSrcT src) {
+ return tensorflow::bit_cast<NativeDestT>(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
+// This template specialization is here to make the compiler happy. bit_cast has
+// a static check that the types are the same size. This specialization should
+// never be used because the source and destination types are checked for
+// identical sizes higher up.
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
+ LOG(FATAL) << "Invalid bitcast between types of different sizes.";
+}
+
+template <PrimitiveType primitive_src_type>
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(
+ ShapeUtil::ChangeElementType(src_literal.shape(), C64));
+ using NativeSrcT =
+ typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
+ tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
+ src_literal.data<NativeSrcT>();
+ tensorflow::gtl::MutableArraySlice<complex64> dest_data =
+ result_literal->data<complex64>();
+ int64 num_elements = src_literal.element_count();
+ for (int64 i = 0; i < num_elements; ++i) {
+ dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
+ }
+ return result_literal;
+}
+
+template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
+ bool bitcast) {
+ CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
+ if (bitcast) {
+ return BitcastBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ } else {
+ return ConvertBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ }
+}
+
+template <PrimitiveType primitive_src_type>
+StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
+ const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
+ switch (primitive_dest_type) {
+#define CONVERT_IF_TYPES_MATCH(type) \
+ case (type): \
+ return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
+ bitcast);
+ CONVERT_IF_TYPES_MATCH(PRED)
+ CONVERT_IF_TYPES_MATCH(S8)
+ CONVERT_IF_TYPES_MATCH(S32)
+ CONVERT_IF_TYPES_MATCH(S64)
+ CONVERT_IF_TYPES_MATCH(U8)
+ CONVERT_IF_TYPES_MATCH(U32)
+ CONVERT_IF_TYPES_MATCH(U64)
+ CONVERT_IF_TYPES_MATCH(F16)
+ CONVERT_IF_TYPES_MATCH(F32)
+ CONVERT_IF_TYPES_MATCH(F64)
+ CONVERT_IF_TYPES_MATCH(BF16)
+#undef CONVERT_IF_TYPES_MATCH
+ case C64:
+ if (!bitcast) {
+ return ConvertToC64<primitive_src_type>(src_literal);
+ }
+ break;
+ // Other types are not yet supported.
+ default:
+ break;
+ }
+ return Unimplemented(
+ "Converting from type %s to type %s is not implemented.",
+ PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str());
+}
+
+StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
+ const LiteralBase& literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
+ TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
+ if (literal.shape().element_type() == primitive_dest_type) {
+ return literal.CloneToUnique();
+ }
+ switch (literal.shape().element_type()) {
+#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
+ case (type): \
+ return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
+ bitcast);
+ CONVERT_IF_DEST_TYPE_MATCHES(PRED)
+ CONVERT_IF_DEST_TYPE_MATCHES(S8)
+ CONVERT_IF_DEST_TYPE_MATCHES(S32)
+ CONVERT_IF_DEST_TYPE_MATCHES(S64)
+ CONVERT_IF_DEST_TYPE_MATCHES(U8)
+ CONVERT_IF_DEST_TYPE_MATCHES(U32)
+ CONVERT_IF_DEST_TYPE_MATCHES(U64)
+ CONVERT_IF_DEST_TYPE_MATCHES(F16)
+ CONVERT_IF_DEST_TYPE_MATCHES(F32)
+ CONVERT_IF_DEST_TYPE_MATCHES(F64)
+ CONVERT_IF_DEST_TYPE_MATCHES(BF16)
+#undef CONVERT_IF_DEST_TYPE_MATCHES
+ // Other types are not yet supported.
+ default:
+ return Unimplemented(
+ "%s from type %s to type %s is not implemented.",
+ (bitcast ? "Bitcast converting" : "Converting"),
+ PrimitiveType_Name(literal.shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str());
+ }
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+ PrimitiveType primitive_dest_type) const {
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+ PrimitiveType primitive_dest_type) const {
+ if (primitive_util::BitWidth(shape().element_type()) !=
+ primitive_util::BitWidth(primitive_dest_type)) {
+ return InvalidArgument(
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "%d",
+ PrimitiveType_Name(shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str(),
+ primitive_util::BitWidth(shape().element_type()),
+ primitive_util::BitWidth(primitive_dest_type));
+ }
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
+}
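+
+// Illustrative sketch (exposition only): bit widths must match, so a bitcast
+// conversion of an F32 literal to S32 succeeds while S64 fails:
+//
+//   Literal f(ShapeUtil::MakeShape(F32, {2}));
+//   auto to_s32 = f.BitcastConvert(S32);  // OK: 32-bit to 32-bit.
+//   auto to_s64 = f.BitcastConvert(S64);  // InvalidArgument: 32 != 64.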
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16) const {
+ if (!ShapeUtil::IsTuple(dest_shape)) {
+ if (round_f32_to_bf16 && shape().element_type() == F32 &&
+ dest_shape.element_type() == BF16) {
+ auto converter = [](float src) {
+ return tensorflow::bfloat16::round_to_bfloat16(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
+ converter);
+ }
+ return Convert(dest_shape.element_type());
+ }
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ auto element = LiteralSlice(*this, {i});
+ TF_ASSIGN_OR_RETURN(
+ auto new_element,
+ element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
+ elements.push_back(std::move(*new_element));
+ }
+ auto converted = MakeUnique<Literal>();
+ *converted = Literal::MoveIntoTuple(&elements);
+ return std::move(converted);
+}
+
+/* static */ Literal Literal::MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements) {
+ std::vector<Shape> element_shapes;
+ for (const Literal& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
+ /*allocate_arrays=*/false);
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
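+
+// Illustrative sketch (exposition only; assumes Literal is
+// move-constructible, as the use of MoveFrom above suggests):
+//
+//   std::vector<Literal> parts;
+//   parts.push_back(Literal(ShapeUtil::MakeShape(F32, {2})));
+//   parts.push_back(Literal(ShapeUtil::MakeShape(S32, {3})));
+//   Literal tuple = Literal::MoveIntoTuple(&parts);
+//   // tuple.shape() is (f32[2], s32[3]); the elements of `parts` are left in
+//   // a moved-from state.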
+
+template <typename NativeT>
+bool LiteralBase::Piece::EqualElementsInternal(
+ const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
+ if (multi_index->size() == ShapeUtil::Rank(subshape())) {
+ return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
+ }
+ for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
+ multi_index->push_back(i);
+ if (!EqualElementsInternal<NativeT>(other, multi_index)) {
+ return false;
+ }
+ multi_index->pop_back();
+ }
+ return true;
+}
+
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
+ DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
+
+ std::vector<int64> multi_index;
+ switch (subshape().element_type()) {
+ case PRED:
+ return EqualElementsInternal<bool>(other, &multi_index);
+ case U8:
+ return EqualElementsInternal<uint8>(other, &multi_index);
+ case S32:
+ return EqualElementsInternal<int32>(other, &multi_index);
+ case S64:
+ return EqualElementsInternal<int64>(other, &multi_index);
+ case U32:
+ return EqualElementsInternal<uint32>(other, &multi_index);
+ case U64:
+ return EqualElementsInternal<uint64>(other, &multi_index);
+ case F32:
+ return EqualElementsInternal<float>(other, &multi_index);
+ case F64:
+ return EqualElementsInternal<double>(other, &multi_index);
+ case F16:
+ return EqualElementsInternal<half>(other, &multi_index);
+ case BF16:
+ return EqualElementsInternal<bfloat16>(other, &multi_index);
+ case C64:
+ return EqualElementsInternal<complex64>(other, &multi_index);
+ default:
+ LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+bool LiteralBase::operator==(const LiteralBase& other) const {
+ if (!ShapeUtil::Compatible(shape(), other.shape())) {
+ return false;
+ }
+
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
+ }
+ return true;
+ });
+}
+
+namespace {
+
+template <typename NativeT>
+static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+ NativeT value) {
+ for (int64 i = 0; i < data.size(); ++i) {
+ if (data[i] != value) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+bool LiteralBase::IsAll(int8 value) const {
+ return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+ const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case U8:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
+ }
+ return false;
+ case U32:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
+ }
+ return false;
+ case U64:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
+ }
+ return false;
+ case S8:
+ return AllElementsEqualValue<int8>(piece.data<int8>(), value);
+ case S32:
+ return AllElementsEqualValue<int32>(piece.data<int32>(), value);
+ case S64:
+ return AllElementsEqualValue<int64>(piece.data<int64>(), value);
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+ static_cast<bfloat16>(value));
+ case PRED:
+ if (value == 0) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), false);
+ }
+ if (value == 1) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), true);
+ }
+ return false;
+ default:
+ return false;
+ }
+ return false;
+ };
+
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsAllFloat(float value) const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsAllComplex(complex64 value) const {
+ switch (shape().element_type()) {
+ case C64:
+ return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
+ value);
+ default:
+ return false;
+ }
+}
+
+bool LiteralBase::IsAllFirst() const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ // Empty shapes are not all the first element since there is no first
+ // element.
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
+ return false;
+ }
+ auto piece_is_all = [&]() {
+ switch (piece.subshape().element_type()) {
+ case PRED: {
+ auto data = piece.data<bool>();
+ return AllElementsEqualValue<bool>(data, data[0]);
+ }
+ // 8 bit types
+ case S8: {
+ auto data = piece.data<int8>();
+ return AllElementsEqualValue<int8>(data, data[0]);
+ }
+ case U8: {
+ auto data = piece.data<uint8>();
+ return AllElementsEqualValue<uint8>(data, data[0]);
+ }
+ // 16 bit types
+ case BF16: {
+ auto data = piece.data<bfloat16>();
+ return AllElementsEqualValue<bfloat16>(data, data[0]);
+ }
+ case F16: {
+ auto data = piece.data<half>();
+ return AllElementsEqualValue<half>(data, data[0]);
+ }
+ case S16: {
+ auto data = piece.data<int16>();
+ return AllElementsEqualValue<int16>(data, data[0]);
+ }
+ case U16: {
+ auto data = piece.data<uint16>();
+ return AllElementsEqualValue<uint16>(data, data[0]);
+ }
+ // 32 bit types
+ case F32: {
+ auto data = piece.data<float>();
+ return AllElementsEqualValue<float>(data, data[0]);
+ }
+ case U32: {
+ auto data = piece.data<uint32>();
+ return AllElementsEqualValue<uint32>(data, data[0]);
+ }
+ case S32: {
+ auto data = piece.data<int32>();
+ return AllElementsEqualValue<int32>(data, data[0]);
+ }
+ // 64 bit types
+ case C64: {
+ auto data = piece.data<complex64>();
+ return AllElementsEqualValue<complex64>(data, data[0]);
+ }
+ case F64: {
+ auto data = piece.data<double>();
+ return AllElementsEqualValue<double>(data, data[0]);
+ }
+ case S64: {
+ auto data = piece.data<int64>();
+ return AllElementsEqualValue<int64>(data, data[0]);
+ }
+ case U64: {
+ auto data = piece.data<uint64>();
+ return AllElementsEqualValue<uint64>(data, data[0]);
+ }
+ default:
+ return false;
+ }
+ };
+
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ switch (shape().element_type()) {
+ case U8:
+ return Get<uint8>(indices) == 0;
+ case U32:
+ return Get<uint32>(indices) == 0;
+ case U64:
+ return Get<uint64>(indices) == 0;
+ case S8:
+ return Get<int8>(indices) == 0;
+ case S32:
+ return Get<int32>(indices) == 0;
+ case S64:
+ return Get<int64>(indices) == 0;
+ case F32:
+ return Get<float>(indices) == 0.0f;
+ case F64:
+ return Get<double>(indices) == 0.0;
+ case C64:
+ return Get<complex64>(indices) == complex64(0.0f, 0.0f);
+ case F16:
+ return Get<half>(indices) == static_cast<half>(0.0f);
+ case BF16:
+ return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
+ case PRED:
+ return Get<bool>(indices) == false;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
+ }
+}
+
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+void CopyToRepeatedField(RepeatedFieldT* dest,
+ const tensorflow::gtl::ArraySlice<NativeT> src) {
+ *dest = RepeatedFieldT(src.begin(), src.end());
+}
+
+} // namespace
+
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
+ *proto->mutable_shape() = subshape();
+ switch (subshape().element_type()) {
+ case PRED:
+ CopyToRepeatedField(proto->mutable_preds(), data<bool>());
+ break;
+ case U8:
+ proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
+ element_count());
+ break;
+ case U32:
+ CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
+ break;
+ case U64:
+ CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
+ break;
+ case S32:
+ CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
+ break;
+ case S64:
+ CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
+ break;
+ case F16:
+ *proto->mutable_f16s() = string(
+ reinterpret_cast<const char*>(data<half>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_f16s());
+ }
+ break;
+ case BF16:
+ *proto->mutable_bf16s() = string(
+ reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_bf16s());
+ }
+ break;
+ case F32:
+ CopyToRepeatedField(proto->mutable_f32s(), data<float>());
+ break;
+ case F64:
+ CopyToRepeatedField(proto->mutable_f64s(), data<double>());
+ break;
+ case C64:
+ for (complex64 value : data<complex64>()) {
+ proto->add_c64s(value.real());
+ proto->add_c64s(value.imag());
+ }
+ break;
+ case TUPLE:
+ case TOKEN:
+ // Nothing to do but assign the shape which is done above.
+ return;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+}
+
+const void* LiteralBase::Piece::untyped_data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+void* LiteralBase::Piece::untyped_data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ const RepeatedFieldT& src) {
+ if (dest.size() != src.size()) {
+ return InvalidArgument(
+ "Expected %lu elements in LiteralProto repeated field, has %d",
+ dest.size(), src.size());
+ }
+ std::copy(src.begin(), src.end(), dest.begin());
+ return Status::OK();
+}
+
+} // namespace
+
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
+ // These conditions should have been checked in Literal::CreateFromProto.
+ TF_RET_CHECK(proto.has_shape());
+ TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+
+ switch (subshape().element_type()) {
+ case PRED:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
+ break;
+ case U8: {
+ auto u8_data = data<uint8>();
+ TF_RET_CHECK(proto.u8s().size() == u8_data.size());
+ std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
+ } break;
+ case S32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
+ break;
+ case S64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
+ break;
+ case U32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
+ break;
+ case U64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
+ break;
+ case F16: {
+ const string& s(proto.f16s());
+ TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+
+ case BF16: {
+ const string& s(proto.bf16s());
+ TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+ case F32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
+ break;
+ case F64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
+ break;
+ case C64: {
+ auto complex_data = data<complex64>();
+ TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
+ for (int64 i = 0; i < complex_data.size(); ++i) {
+ complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
+ }
+ } break;
+ case TUPLE:
+ LOG(FATAL) << "Should not be called on tuple shapes: "
+ << ShapeUtil::HumanString(subshape());
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+ return Status::OK();
+}
+
+LiteralProto LiteralBase::ToProto() const {
+ LiteralProto proto;
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
+
+ if (LayoutUtil::IsSparseArray(shape())) {
+ CopyToRepeatedField(proto.mutable_sparse_indices(),
+ sparse_indices()->data());
+ }
+
+ return proto;
+}
+
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
+ return piece(shape_index).untyped_data();
+}
+
+void* Literal::untyped_data(const ShapeIndex& shape_index) {
+ return piece(shape_index).untyped_data();
+}
+
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
+ return piece(shape_index).size_bytes();
+}
+
+string LiteralBase::GetR1U8AsString() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(shape().element_type(), U8);
+ return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+ ShapeUtil::ElementsIn(shape()));
+}
+
+void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
+ CHECK(ShapeUtil::IsTuple(shape));
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ if (ShapeUtil::IsTuple(subshape)) {
+ BuildPieceSubtree(subshape, &child_piece);
+ }
+
+ piece->emplace_back(std::move(child_piece));
+ }
+}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
+BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsArray(*shape_));
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = Piece();
+ root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
+ root_piece_.set_subshape(shape_.get());
+}
+
+BorrowingLiteral::BorrowingLiteral(
+ tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsTuple(*shape_));
+ CHECK(!ShapeUtil::IsNestedTuple(*shape_));
+ CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
+ root_piece_ = Piece();
+ root_piece_.set_subshape(shape_.get());
+ BuildPieceSubtree(*shape_, &root_piece_);
+
+ for (int i = 0; i < src_buf_ptrs.size(); ++i) {
+ const auto& src_shape = shape_->tuple_shapes(i);
+ CHECK(ShapeUtil::IsArray(src_shape));
+ root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
new file mode 100644
index 0000000000..dd67dfa8d4
--- /dev/null
+++ b/tensorflow/compiler/xla/literal.h
@@ -0,0 +1,1152 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_H_
+
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/sparse_index_array.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Forward declare the Literal and LiteralSlice classes to be used by the
+// creation methods in the base class.
+class Literal;
+class LiteralSlice;
+
+// Abstract base class for literals.
+class LiteralBase {
+ public:
+ virtual ~LiteralBase() = 0;
+
+ // Literals are equal if they have compatible shapes and the same data
+ // values. Layout is not compared.
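+ //
+ // For example (an illustrative sketch; LayoutUtil::MakeLayout is assumed
+ // from layout_util.h):
+ //
+ //   Literal a(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ //   std::unique_ptr<Literal> b = a.Relayout(LayoutUtil::MakeLayout({0, 1}));
+ //   // a == *b holds even though the layouts differ.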
+ bool operator==(const LiteralBase& other) const;
+ bool operator!=(const LiteralBase& other) const { return !(*this == other); }
+
+ // Returns the shape of the literal.
+ const Shape& shape() const { return root_piece().subshape(); }
+
+ // Serialize to proto.
+ LiteralProto ToProto() const;
+
+ // Returns an ArraySlice of the array for this literal for the given NativeT
+ // (e.g., float). CHECKs if the subshape of the literal at the given
+ // ShapeIndex is not an array. See primitive_util.h for the mapping from XLA
+ // type to native type.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to the sparse index array. Returns nullptr if the
+ // literal is not a sparse array.
+ const SparseIndexArray* sparse_indices(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to (or size of) the underlying buffer holding the
+ // array at the given shape index. CHECKs if the subshape of the literal at
+ // the given ShapeIndex is not an array.
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+
+ // Returns this literal's data as a string. This literal must be a rank-1 U8
+ // array.
+ string GetR1U8AsString() const;
+
+ // Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
+ string ToString(bool print_layout = false) const;
+
+ // Gets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const;
+ // Overloads of Get for array literals. CHECKs if the literal is not a dense
+ // array.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the element value at index (0, ..., 0), however many zeroes are
+ // required for that index.
+ template <typename NativeT>
+ NativeT GetFirstElement() const;
+
+ // As Get(), but determines the correct type and converts the value
+ // into text.
+ string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index = {}) const;
+ // As GetSparseElement(), but determines the correct type and converts the
+ // value into text.
+ string GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+ // As Get(), but determines the correct type and converts the value into
+ // int64. This literal must be an array.
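+ // For example (an illustrative sketch):
+ //
+ //   Literal lit(ShapeUtil::MakeShape(S32, {2}));
+ //   lit.Set<int32>({0}, 7);
+ //   StatusOr<int64> v = lit.GetIntegralAsS64({0});  // contains 7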
+ StatusOr<int64> GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the multi-index of the element in a sparse literal at the given
+ // sparse element number. The sparse element number is the position within
+ // the sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+ // Returns the value of the element in a sparse literal at the given sparse
+ // element number. The sparse element number is the position within the
+ // sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ template <typename NativeT>
+ NativeT GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
+ // Invokes the "per cell" callback for each element in the provided
+ // literal with the element's indices and a string representation of
+ // the element's value.
+ //
+ // This function is useful if you want a polymorphic representation
+ // of the tensor's elements (turning them into strings, e.g. for
+ // representation in a protobuf).
+ //
+ // This literal must have a dense layout.
+ void EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const;
+ template <typename NativeT>
+ void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const;
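+ //
+ // For example (an illustrative sketch), given an array-shaped literal `lit`:
+ //
+ //   lit.EachCellAsString(
+ //       [](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ //         LOG(INFO) << "[" << tensorflow::str_util::Join(indices, ",")
+ //                   << "] = " << value;
+ //       });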
+
+ // Returns whether every element in this literal is equal to value.
+ //
+ // value is an int8 because we expect this to be called with small
+ // compile-time constants (0, -1, etc.) and so that whatever value you pass
+ // can be represented exactly by floating-point types as small as 16 bits.
+ //
+ // If value doesn't fit in this literal's type, returns false. Values of 1/0
+ // are considered equal to true/false; other values are not considered equal
+ // to true. Also, if this literal is not array-shaped, false is returned.
+ bool IsAll(int8 value) const;
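+ //
+ // For example (an illustrative sketch): a freshly constructed literal is
+ // zero-initialized, so
+ //
+ //   Literal zeros(ShapeUtil::MakeShape(F32, {2, 2}));
+ //   // zeros.IsAll(0) is true; zeros.IsAll(1) is false.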
+
+ // Like IsAll(int8), except we check whether the literal is
+ // equal to a particular floating-point number.
+ //
+ // If the literal is not a floating-point value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for values that can be expressed precisely as a float,
+ // e.g. -0.5. Also, if this literal is not array-shaped, false is returned.
+ bool IsAllFloat(float value) const;
+
+ // Like IsAll(int8), except we check whether the literal is
+ // equal to a particular complex number.
+ //
+ // If the literal is not a complex value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for complex values that can be expressed precisely as
+ // float pairs e.g. (-0.5, 1.0).
+ //
+ // This literal must have a dense layout.
+ bool IsAllComplex(complex64 value) const;
+
+ // Returns whether this literal consists entirely of its first element.
+ bool IsAllFirst() const;
+
+ // Returns whether this literal is zero at the specified index. This literal
+ // must be an array with a dense layout.
+ bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+
+ // Returns the count of the elements in the array at the given shape index in
+ // this literal.
+ int64 element_count(const ShapeIndex& index = {}) const {
+ return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ }
+
+ // Returns the count of the elements in the sparse array in this literal,
+ // which will be no larger than
+ // LayoutUtil::MaxSparseElements(shape().layout()).
+ int64 sparse_element_count() const;
+
+ // Compute a hash for this literal. This literal must not be a sparse tensor
+ // or a tuple containing a sparse tensor.
+ size_t Hash() const;
+
+ // Converts this literal to the given shape. Returns an error if the
+ // conversion is not possible.
+ //
+ // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+ // instead of truncation; otherwise, truncation is used.
+ //
+ // TODO(b/69266521): remove the round_f32_to_bf16 flag when rounding becomes
+ // the default behavior.
+ StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+
+ // Converts this literal to another primitive type using a bitcast
+ // conversion. The to and from primitive types must have the same bit
+ // width. Returns an error if the conversion is not possible. This literal
+ // must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Converts this literal to another primitive type. Returns an error if the
+ // conversion is not possible. This literal must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> Convert(
+ PrimitiveType primitive_dest_type) const;
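+ //
+ // For example (an illustrative sketch; StatusOr::ValueOrDie is assumed from
+ // statusor.h):
+ //
+ //   Literal f(ShapeUtil::MakeShape(F32, {2}));
+ //   f.Set<float>({0}, 2.5f);
+ //   StatusOr<std::unique_ptr<Literal>> s = f.Convert(S32);
+ //   // On success, s.ValueOrDie()->Get<int32>({0}) == 2 (static_cast
+ //   // semantics).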
+
+ // Clones the underlying buffers into a new Literal, or new
+ // std::unique_ptr<Literal>.
+ Literal Clone() const;
+ std::unique_ptr<Literal> CloneToUnique() const;
+
+ // TODO(b/67651157): The methods below which perform computation on Literals
+ // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
+ // evaluator code which operates on Literals.
+ //
+ // Creates a new literal with a value equivalent to this literal's, but
+ // conforming to new_layout; e.g. a literal matrix that was in {0, 1}
+ // minor-to-major dimension layout can be re-laid-out as {1, 0}
+ // minor-to-major dimension layout and the value in the cell at any given
+ // logical index (i0, i1) will be the same.
+ //
+ // For tuple shaped literals, shape_index should be used to select the inner
+ // array that the new layout applies to.
+ //
+ // Note: this is useful when the client wants to ensure that a value placed in
+ // the XLA allocation tracker has a particular layout; for efficiency
+ // purposes or avoiding unimplemented operation/layout combinations.
+ std::unique_ptr<Literal> Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
+
+ // An overload of Relayout which changes the layout of the entire shape rather
+ // than being limited to a single array within the shape.
+ std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+
+ // Creates a new literal by reshaping this literal to have the given
+ // dimensions. The total number of elements must not change; the
+ // implementation currently only supports monotonic dim0-major layouts.
+ // This literal must be an array.
+ StatusOr<std::unique_ptr<Literal>> Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
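+ //
+ // For example (an illustrative sketch):
+ //
+ //   Literal lit(ShapeUtil::MakeShape(F32, {2, 3}));
+ //   StatusOr<std::unique_ptr<Literal>> r = lit.Reshape({3, 2});
+ //   // On success, the result has shape F32[3,2] with the same six elements
+ //   // in dim0-major order.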
+
+ // Creates a new literal by broadcasting this literal with `dimensions` to
+ // yield a literal of shape `result_shape`.
+ StatusOr<std::unique_ptr<Literal>> Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+ // Creates a new literal by reordering the dimensions of this literal.
+ // The given `permutation` must be a permutation of the dimension numbers
+ // in the original literal, and it specifies the order of the new dimensions
+ // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+ // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+ // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ // This literal must be an array.
+ std::unique_ptr<Literal> Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const;
+
+ // Creates a sub-array from this literal by extracting the indices
+ // [start_index, limit_index) of each dimension. The result literal has the
+ // same rank and layout as for the given literal. The number of indices in
+ // start_indices and limit_indices must be the rank of the literal, and the
+ // indices follow the order of the dimensions.
+ // This literal must be an array.
+ std::unique_ptr<Literal> Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const;
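+ //
+ // For example (an illustrative sketch): taking the bottom-right 1x2 corner
+ // of a 2x3 array:
+ //
+ //   Literal lit(ShapeUtil::MakeShape(F32, {2, 3}));
+ //   std::unique_ptr<Literal> s = lit.Slice({1, 1}, {2, 3});
+ //   // s->shape() is F32[1,2] and s->Get<float>({0, 0}) equals
+ //   // lit.Get<float>({1, 1}).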
+
+ // Creates a literal with a prepended dimension with bound "times"; e.g. a
+ // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
+ // literal replicated four times.
+ // This literal must be an array.
+ template <typename NativeT>
+ std::unique_ptr<Literal> Replicate(int64 times) const;
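+ //
+ // For example (an illustrative sketch):
+ //
+ //   Literal row(ShapeUtil::MakeShape(F32, {3}));
+ //   std::unique_ptr<Literal> stacked = row.Replicate<float>(4);
+ //   // stacked->shape() is F32[4,3].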
+
+ // Creates a new Literal object with the shape specified as a parameter.
+ // The literal's values are initialized to the default value of its
+ // primitive type (0 for numeric types, false for predicates).
+ //
+ // Note: It's an antipattern to use this method and then immediately call
+ // Literal::Populate on the result (since that results in zero
+ // initialization, then reinitialization). Consider whether a call to
+ // MakeUnique<Literal>(shape), followed by a call to Literal::Populate, can
+ // be used instead.
+ static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ protected:
+ // A data structure representing a subshape at a particular ShapeIndex within
+ // the literal. For array-shaped ShapeIndexes, this data structure holds the
+ // pointer to the memory allocated for the array data.
+ class Piece {
+ public:
+ // Returns the buffer holding the array data for this piece as an array
+ // slice. This piece must be array-shaped.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data() const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+ // Returns the buffer holding the array data for this piece as a void*. This
+ // piece must be array-shaped.
+ void* untyped_data();
+ const void* untyped_data() const;
+
+ // Gets or sets an element in the array at the given index. The multi_index
+ // is CHECKed against the dimension sizes of the array. This piece must be
+ // array-shaped.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+ // Gets/sets the buffer holding the array data.
+ char* buffer() const { return buffer_; }
+ void set_buffer(char* buffer) { buffer_ = buffer; }
+
+ // The array of multi-indices that provide the locations of non-zero
+ // elements in a sparse array. Only used if
+ // LayoutUtil::IsSparseArray(shape()) is true.
+ SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+ void set_sparse_indices(SparseIndexArray* sparse_indices) {
+ sparse_indices_ = sparse_indices;
+ }
+
+ // Gets or sets the subshape of this piece. This reference points to a
+ // subshape within the shape in the containing Literal (Literal::shape_).
+ const Shape& subshape() const { return *subshape_; }
+ void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+ // Returns the size in bytes of the buffer holding the array data.
+ int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+ // Returns the number of elements in this piece's array.
+ int64 element_count() const {
+ // If this is a sparse array, use the number of elements represented by
+ // the indices in the associated SparseIndexArray.
+ return LayoutUtil::IsSparseArray(subshape())
+ ? sparse_indices()->index_count()
+ : ShapeUtil::ElementsIn(subshape());
+ }
+
+ // Returns the child piece at 'index' of this piece.
+ Piece& child(int64 index) { return children_[index]; }
+
+ // Adds a child piece to this piece's children.
+ void emplace_back(Piece child_piece) {
+ children_.emplace_back(std::move(child_piece));
+ }
+
+ // Returns the size of children pieces of this piece.
+ int64 children_size() { return children_.size(); }
+
+ // Visitor functions that recursively traverse the piece and call the
+ // given function at each child piece. The function has the type:
+ // void (const ShapeIndex& index, const Piece& piece)
+ template <typename Fn>
+ void ForEachSubpiece(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(
+ [&func](const ShapeIndex& index, const Piece& piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ *this, &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, const Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachSubpieceWithStatus(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // bool (const ShapeIndex& index, const Piece& piece)
+ // Traversal stops at the first false return value, which is then returned.
+ template <typename Fn>
+ bool ForEachSubpieceWithBool(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelperBool(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // void (const ShapeIndex& index, Piece* piece)
+ template <typename Fn>
+ void ForEachMutableSubpiece(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ [&func](const ShapeIndex& index, Piece* piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ const_cast<xla::LiteralBase::Piece*>(this), &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ //   Status (const ShapeIndex& index, Piece* piece)
+ // The first non-OK status returned by the function stops the traversal and
+ // is returned.
+ template <typename Fn>
+ Status ForEachMutableSubpieceWithStatus(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ func, const_cast<xla::LiteralBase::Piece*>(this), &index);
+ }
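+
+ // A minimal usage sketch for the visitors above (illustrative only;
+ // 'piece' is a hypothetical Piece):
+ //
+ //   piece.ForEachSubpiece(
+ //       [](const ShapeIndex& index, const Piece& subpiece) {
+ //         VLOG(2) << "visiting subpiece at " << index.ToString();
+ //       });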
+
+ // Returns true if this piece and 'other' contain the same data. This piece
+ // and 'other' must be array-shaped and compatible.
+ bool EqualElements(const Piece& other) const;
+
+ // Writes the shape and data (if array-shaped) into the given proto.
+ void WriteToProto(LiteralProto* proto) const;
+
+ // Copies the data from 'src' into this piece's buffer. Shapes of this piece
+ // and src must be compatible.
+ Status CopyFrom(const Piece& src);
+
+ // Copies the data from the given proto into this piece. The shape of this
+ // piece must be equal (not just compatible) to the shape of the proto.
+ Status CopyFromProto(const LiteralProto& proto);
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements();
+
+ private:
+ // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
+ // The first non-OK (or non-true) value is returned by the function.
+ // The callable 'func' has the same signature as described above in
+ // ForEachSubpiece*.
+ template <typename Fn>
+ Status ForEachHelper(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+ template <typename Fn>
+ bool ForEachHelperBool(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ if (!func(*index, piece)) {
+ return false;
+ }
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ if (!ForEachHelperBool(func, piece.children_[i], index)) {
+ return false;
+ }
+ index->pop_back();
+ }
+ return true;
+ }
+ template <typename Fn>
+ Status ForEachMutableHelper(const Fn& func, Piece* piece,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece->children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(
+ ForEachMutableHelper(func, &piece->children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+
+ // Recursive helper for EqualElements.
+ template <typename NativeT>
+ bool EqualElementsInternal(const Piece& other,
+ std::vector<int64>* multi_index) const;
+
+ // Helper for SortSparseElements that has the element type as a template
+ // parameter.
+ template <typename NativeT>
+ void SortSparseElementsInternal();
+
+ // For array-shaped pieces, this is the buffer holding the literal data.
+ char* buffer_ = nullptr;
+
+ // For sparse arrays, this is the array of indices.
+ SparseIndexArray* sparse_indices_ = nullptr;
+
+ // The shape of this piece. This points into the shape of the containing
+ // Literal (Literal::shape_).
+ const Shape* subshape_ = nullptr;
+
+ // Children pieces for tuple shaped pieces.
+ std::vector<Piece> children_ = {};
+ }; // class Piece
+
+ // Returns the piece at the given ShapeIndex.
+ const Piece& piece(const ShapeIndex& shape_index) const {
+ Piece* piece = &const_cast<Piece&>(root_piece());
+ for (const auto i : shape_index) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, piece->children_size());
+ piece = &piece->child(i);
+ }
+ return *piece;
+ }
+
+ // Returns the piece at the root of the shape.
+ virtual const Piece& root_piece() const = 0;
+
+ // LiteralSlice and Literal must access Pieces of other Literals.
+ friend class Literal;
+ friend class LiteralSlice;
+ friend class BorrowingLiteral;
+
+ private:
+ template <typename NativeT>
+ std::unique_ptr<Literal> SliceInternal(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices) const;
+};
+
+// Class representing literal values in XLA.
+//
+// The underlying buffer and shape are always owned by this class.
+class Literal : public LiteralBase {
+ public:
+ Literal() : Literal(ShapeUtil::MakeNil()) {}
+
+ // Create a literal of the given shape. The literal is allocated sufficient
+ // memory to hold the shape. Memory is uninitialized.
+ explicit Literal(const Shape& shape);
+ virtual ~Literal();
+
+ // Literals are movable, but not copyable. To copy a literal use
+ // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+ // of literals, which can be expensive.
+ Literal(const Literal& other) = delete;
+ Literal& operator=(const Literal& other) = delete;
+ Literal(Literal&& other);
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+ // the shape. If false, buffer pointers inside of the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
+ Literal& operator=(Literal&& other);
+
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
+
+ // Returns a MutableArraySlice view of the array for this literal for the
+ // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+ // given ShapeIndex is not an array. See primitive_util.h for the mapping
+ // from XLA type to native type.
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::data;
+
+ // Returns a pointer to the sparse index array. Returns nullptr if the literal
+ // is not a sparse array.
+ SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+
+ // Returns a pointer to the underlying buffer holding the array at the given
+ // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
+ // is not an array.
+ void* untyped_data(const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::untyped_data;
+
+ // Populates a literal with a sparse layout with the given indices and values.
+ // Each index in the indices array is CHECKed against the dimensions in the
+ // literal's shape. If sort is true, then the indices and values will be
+ // sorted. If sort is false, then the indices and values are assumed to
+ // already be in sorted order. See LiteralUtil::CreateSparse for an example
+ // of how data is populated.
+ template <typename NativeT>
+ void PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort = true);
+
+ // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+ // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+ // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+ // rooted at 'src_shape_index', but need not be arrays.
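+ //
+ // A usage sketch (illustrative; 'dst' is a hypothetical tuple-shaped
+ // literal whose element 0 is compatible with the hypothetical 'src'):
+ //
+ //   TF_RETURN_IF_ERROR(dst.CopyFrom(src, /*dest_shape_index=*/{0},
+ //                                   /*src_shape_index=*/{}));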
+ Status CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index = {},
+ const ShapeIndex& src_shape_index = {});
+
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple).
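+ //
+ // For example (a sketch; 'tuple' is a hypothetical tuple-shaped Literal):
+ //
+ //   std::vector<Literal> elements = tuple.DecomposeTuple();
+ //   // 'tuple' is now nil-shaped; each element owns its own buffers.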
+ std::vector<Literal> DecomposeTuple();
+
+ // Similar to CopyFrom, but with move semantics. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
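+ //
+ // A usage sketch (illustrative; 'dst' and 'src' are hypothetical literals
+ // with equal shapes):
+ //
+ //   TF_RETURN_IF_ERROR(dst.MoveFrom(std::move(src)));
+ //   // 'src' is now nil-shaped; its buffers belong to 'dst'.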
+ Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Copies the values from src_literal, starting at src_base shape indexes,
+ // to this literal, starting at dest_base, where the copy size in each
+ // dimension is specified by copy_size.
+ // The src_literal and this literal must have the same primitive type,
+ // src_base+copy_size must fit within the source literal dimensions, and
+ // dest_base+copy_size must fit within the destination literal dimensions.
+ // Note: if either src_literal or this literal contains dimensions with zero
+ // elements, then copy_size must be 0 in these dimensions and the
+ // corresponding base indices must be 0.
+ // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Copies one element from src_literal[src_index] to (*this)[dest_index].
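+ //
+ // For example (a sketch; 'dst' and 'src' are hypothetical array literals of
+ // the same primitive type):
+ //
+ //   TF_RETURN_IF_ERROR(dst.CopyElementFrom(src, /*src_index=*/{1, 2},
+ //                                          /*dest_index=*/{0, 0}));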
+ Status CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index);
+
+ // Sets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value);
+ // Overload of Set for array literals. CHECKs if the literal is not a dense
+ // array.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+ // Appends the given element to the literal. If the elements are not appended
+ // in sorted order, then SortSparseElements should be called before calling
+ // other methods. This literal must have a sparse layout.
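+ //
+ // For example (a sketch; 'literal' is a hypothetical rank-3 sparse S64
+ // literal):
+ //
+ //   literal.AppendSparseElement<int64>({0, 1, 2}, 42);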
+ template <typename NativeT>
+ void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value, const ShapeIndex& shape_index = {});
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements(const ShapeIndex& shape_index = {});
+
+ // As Set(), but truncates `value` to the literal element type before storing.
+ // This literal must be an array.
+ Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value);
+
+ // Populate this literal with the given values. Examples:
+ //
+ // // Populate with floats.
+ // Array2D<float> float_values = ...
+ // literal.PopulateR2FromArray2D(float_values);
+ //
+ // // Populate with int32s.
+ // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
+ //
+ // The shape and element type of this literal must match the given values.
+ // For example, in the call above to literal.PopulateR2(), 'literal' must be
+ // a 2x2 array of S32.
+ template <typename NativeT>
+ void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ void PopulateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Populates literal values by calling the generator function for every cell
+ // in this literal object.
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ //
+ // This literal must have a dense layout.
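+ //
+ // A usage sketch: fill a hypothetical f32[3,2] literal with the sum of the
+ // element indices.
+ //
+ //   Literal literal(ShapeUtil::MakeShape(F32, {3, 2}));
+ //   TF_CHECK_OK(literal.Populate<float>(
+ //       [](tensorflow::gtl::ArraySlice<int64> indexes) {
+ //         return static_cast<float>(indexes[0] + indexes[1]);
+ //       }));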
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
+
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
+
+ // Fills this literal with the given value.
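+ //
+ // For example (a sketch):
+ //
+ //   Literal ones(ShapeUtil::MakeShape(F32, {2, 2}));
+ //   ones.PopulateWithValue<float>(1.0f);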
+ template <typename NativeT>
+ void PopulateWithValue(NativeT value);
+
+ // This operation is the inverse of DecomposeTuple. The given elements are
+ // moved into the tuple elements of a new tuple-shaped Literal which is
+ // returned. Upon return, each of the Literals in 'elements' is set to a nil
+ // shape (empty tuple).
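+ //
+ // A round-trip sketch, rebuilding the tuple decomposed in the
+ // DecomposeTuple example above:
+ //
+ //   Literal rebuilt = Literal::MoveIntoTuple(&elements);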
+ static Literal MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements);
+
+ // Deserializes a literal from the given proto.
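+ //
+ // For example (a sketch; 'proto' is a hypothetical LiteralProto):
+ //
+ //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ //                       Literal::CreateFromProto(proto));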
+ static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+ const LiteralProto& proto);
+
+ private:
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
+ // in the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
+
+ // Returns the piece at the given ShapeIndex.
+ Piece& piece(const ShapeIndex& shape_index) {
+ return const_cast<Piece&>(LiteralBase::piece(shape_index));
+ }
+
+ Piece& root_piece() const override { return *root_piece_; };
+
+ // Internal template helper for Literal::CopySliceFrom(), matching its
+ // arguments one by one.
+ template <typename NativeT>
+ Status CopySliceFromInternal(const LiteralBase& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Utility structure used to create the optimal configuration for
+ // a ShapeUtil::ForEachIndex() scan across two literals.
+ struct StrideConfig {
+ StrideConfig(const Shape& source_shape, const Shape& dest_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // The dimensions of the stride operation. Essentially every dimension
+ // will be iterated from base[i] to base[i]+dimensions[i], in steps of
+ // step[i].
+ tensorflow::gtl::ArraySlice<int64> dimensions;
+ DimensionVector base;
+ DimensionVector step;
+ int64 minor_dimension = 0;
+ // The size of the strides for source and destination. One of the two
+ // (the one looping through its most minor dimension) will be 1, while
+ // the other will be the stride size at the dimension matching the other
+ // shape's most minor dimension being scanned.
+ int64 dest_stride = 1;
+ int64 source_stride = 1;
+ // The size of the inner loop on the most minor dimension.
+ int64 minor_loop_size = 1;
+ };
+
+ // The Literal class always owns the shape. The parent class borrows it.
+ std::unique_ptr<Shape> shape_;
+
+ Piece* root_piece_ = nullptr;
+
+ // Implementation details shared between Populate() and PopulateParallel().
+ template <typename NativeT, typename FnType>
+ Status PopulateInternal(const FnType& generator, bool parallel);
+
+ // Deallocate the buffers held by this literal.
+ void DeallocateBuffers();
+
+ friend class LiteralBase;
+};
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
+
+// A read-only view of a Literal. A LiteralSlice contains pointers to the
+// shape and literal buffers, which are always owned by others.
+class LiteralSlice : public LiteralBase {
+ public:
+ LiteralSlice() : LiteralBase() {}
+
+ // Implicit conversion constructors.
+ LiteralSlice(const LiteralBase& literal);
+ LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
+
+ private:
+ const Piece& root_piece() const override { return *root_piece_; };
+
+ const Piece* root_piece_; // Not owned.
+};
+
+// A read-only Literal where the underlying buffers are never owned by this
+// class.
+class BorrowingLiteral : public LiteralBase {
+ public:
+ BorrowingLiteral() : LiteralBase() {}
+
+ // 'src_buf_ptr' is not owned by this class and must outlive instances of
+ // this class. It points to an appropriately sized buffer with data
+ // interpreted as indicated by 'shape'.
+ // This constructor is only used for array shapes.
+ BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
+ // Similar to the above, except used for constructing non-nested tuples.
+ BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
+ const Shape& shape);
+ // TODO(b/79707221): add constructors for nested tuples as well.
+
+ private:
+ // Recursively builds the subtree for the given piece and sets its subshapes
+ // from the given shape.
+ void BuildPieceSubtree(const Shape& shape, Piece* piece);
+
+ // Accessor for the root piece of this literal.
+ const Piece& root_piece() const override { return root_piece_; };
+ Piece root_piece_;
+
+ // Shape of this literal. Stored as a unique_ptr so that the (default) move
+ // construction of this class is trivially correct: the Shape pointer that
+ // root_piece_ stores will still point to the correct address.
+ std::unique_ptr<Shape> shape_;
+};
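+
+// A usage sketch (illustrative): wrap a caller-owned buffer as an f32[2,2]
+// literal without copying; the buffer must outlive the view.
+//
+//   float data[4] = {1.f, 2.f, 3.f, 4.f};
+//   BorrowingLiteral view(reinterpret_cast<const char*>(data),
+//                         ShapeUtil::MakeShape(F32, {2, 2}));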
+
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::ArraySlice<NativeT>(
+ reinterpret_cast<const NativeT*>(buffer()), element_count());
+}
+
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::MutableArraySlice<NativeT>(
+ reinterpret_cast<NativeT*>(buffer()), element_count());
+}
+
+template <typename NativeT>
+NativeT LiteralBase::Piece::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
+ return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)];
+}
+
+template <typename NativeT>
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
+ data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)] = value;
+}
+
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).data<NativeT>();
+}
+
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+ const ShapeIndex& shape_index) {
+ return piece(shape_index).data<NativeT>();
+}
+
+template <typename NativeT>
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).Get<NativeT>(multi_index);
+}
+
+template <typename NativeT>
+inline NativeT LiteralBase::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ return root_piece().Get<NativeT>(multi_index);
+}
+
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value) {
+ return piece(shape_index).Set<NativeT>(multi_index, value);
+}
+
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ return root_piece().Set<NativeT>(multi_index, value);
+}
+
+template <typename NativeT>
+NativeT LiteralBase::GetFirstElement() const {
+ return data<NativeT>().at(0);
+}
+
+template <typename NativeT>
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
+ CHECK(
+ LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
+ return data<NativeT>(shape_index)[sparse_element_number];
+}
+
+template <typename NativeT>
+void Literal::AppendSparseElement(
+ tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
+ const ShapeIndex& shape_index) {
+ Piece& p = piece(shape_index);
+ const Shape& subshape = p.subshape();
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ int64 rank = ShapeUtil::Rank(subshape);
+ CHECK_EQ(multi_index.size(), rank);
+ int64 last_element = p.sparse_indices()->index_count();
+ CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
+ p.sparse_indices()->Append(multi_index);
+ CHECK_LT(last_element, p.data<NativeT>().size());
+ p.data<NativeT>()[last_element] = value;
+}
+
+template <typename NativeT>
+void LiteralBase::EachCell(
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
+ return;
+ }
+ std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
+ do {
+ per_cell(indices, Get<NativeT>(indices));
+ } while (IndexUtil::BumpIndices(shape(), &indices));
+}
+
+template <typename NativeT>
+inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ for (int64 i = 0; i < values.size(); ++i) {
+ Set({i}, values[i]);
+ }
+}
+
+template <typename NativeT>
+void Literal::PopulateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 2);
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+
+ const int64 dim0_size = values.size();
+ const int64 dim1_size = values.begin()->size();
+ CHECK_EQ(dim0_size, shape().dimensions(0));
+ CHECK_EQ(dim1_size, shape().dimensions(1));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto value : inner_list) {
+ Set({dim0, dim1}, value);
+ ++dim1;
+ }
+ CHECK_EQ(dim1_size, dim1);
+ ++dim0;
+ }
+}
+
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
+ for (int dim = 0; dim < values.num_dimensions(); ++dim) {
+ CHECK_EQ(values.dim(dim), shape().dimensions(dim));
+ }
+ values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value) { this->Set(indices, value); });
+}
+
+template <typename NativeT>
+void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort) {
+ CHECK(LayoutUtil::IsSparseArray(shape()));
+ int rank = ShapeUtil::Rank(shape());
+ CHECK_EQ(indices.rank(), rank);
+ int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
+ CHECK_LE(indices.max_indices(), max_elements);
+ int64 num_elements = values.size();
+ CHECK_LE(num_elements, max_elements);
+ CHECK_EQ(num_elements, indices.index_count());
+ auto root_data = root_piece().data<NativeT>();
+ // Piece::data() returns an ArraySlice of size equal to the number of indices
+ // in the SparseIndexArray. So there is no need to adjust the size of the data
+ // here. It is enough to just copy the incoming values into the data buffer.
+ std::copy(values.begin(), values.end(), root_data.begin());
+ *this->root_piece().sparse_indices() = std::move(indices);
+ if (sort) {
+ auto root_data = this->root_piece().data<NativeT>();
+ this->root_piece().sparse_indices()->SortWithValues(root_data);
+ }
+ DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
+ const Shape& this_shape = shape();
+ const int64 rank = ShapeUtil::Rank(this_shape);
+ TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
+ TF_RET_CHECK(this_shape.element_type() ==
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
+ if (rank > 0) {
+ StrideConfig stride_config(this_shape, this_shape,
+ AsInt64Slice(this_shape.dimensions()));
+ int64 minor_dimension_size =
+ ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
+
+ auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ DimensionVector minor_scan_indexes(rank, 0);
+ const int64 index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
+ std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
+ for (int64 i = 0; i < minor_dimension_size; ++i) {
+ minor_scan_indexes[stride_config.minor_dimension] = i;
+ literal_data.at(index + i) = generator(minor_scan_indexes);
+ }
+ };
+ if (parallel) {
+ ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
+ stride_config.dimensions,
+ stride_config.step, init_function);
+ } else {
+ ShapeUtil::ForEachIndex(
+ this_shape, stride_config.base, stride_config.dimensions,
+ stride_config.step,
+ [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+ init_function(indexes);
+ return true;
+ });
+ }
+ } else {
+ // For scalars.
+ literal_data.at(0) = generator({});
+ }
+ return Status::OK();
+}
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/false);
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateParallel(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/true);
+}
+
+template <typename NativeT>
+void Literal::PopulateWithValue(NativeT value) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ for (NativeT& element : data<NativeT>()) {
+ element = value;
+ }
+}
+
+template <typename NativeT>
+std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+ DimensionVector bounds = {times};
+ bounds.reserve(shape().dimensions_size() + 1);
+ for (int64 bound : shape().dimensions()) {
+ bounds.push_back(bound);
+ }
+ auto literal =
+ MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ if (elements == 0) {
+ return literal;
+ }
+
+ DimensionVector output_indices(bounds.size(), 0);
+ tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
+ input_indices.remove_prefix(1);
+
+ bool done = false;
+ while (!done) {
+ const auto element = Get<NativeT>(input_indices);
+ literal->Set<NativeT>(output_indices, element);
+
+ done = true;
+ for (int n = 0; n < output_indices.size(); ++n) {
+ ++output_indices[n];
+ if (output_indices[n] < bounds[n]) {
+ done = false;
+ break;
+ }
+ output_indices[n] = 0;
+ }
+ }
+ return literal;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 2125ab7c61..94993cc874 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <cmath>
#include <vector>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -217,7 +218,7 @@ class NearComparator {
return Printf(
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
- Literal::MultiIndexAsString(
+ LiteralUtil::MultiIndexAsString(
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
linear_index))
.c_str(),
@@ -722,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
return AppendStatus(result,
tensorflow::strings::Printf(
"\nat index: %s\nexpected: %s\nactual: %s",
- Literal::MultiIndexAsString(multi_index).c_str(),
+ LiteralUtil::MultiIndexAsString(multi_index).c_str(),
ToStringTruncated(expected).c_str(),
ToStringTruncated(actual).c_str()));
}
diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h
index 00a13e3619..9e5bf7c1d0 100644
--- a/tensorflow/compiler/xla/literal_comparison.h
+++ b/tensorflow/compiler/xla/literal_comparison.h
@@ -20,7 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
#include "tensorflow/compiler/xla/error_spec.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 493d807591..e8f919950f 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include <vector>
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
@@ -76,11 +77,11 @@ class LiteralUtilTest : public ::testing::Test {
layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
literal_r4_2x2x3x3_dim0major_ =
- Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
- layout_r4_dim0major_);
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0major_);
literal_r4_2x2x3x3_dim0minor_ =
- Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
- layout_r4_dim0minor_);
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0minor_);
}
Layout layout_r2_dim0major_;
@@ -94,47 +95,47 @@ class LiteralUtilTest : public ::testing::Test {
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
- auto true_lit = Literal::CreateR0<bool>(true);
+ auto true_lit = LiteralUtil::CreateR0<bool>(true);
ASSERT_EQ("true", true_lit->ToString());
- auto false_lit = Literal::CreateR0<bool>(false);
+ auto false_lit = LiteralUtil::CreateR0<bool>(false);
ASSERT_EQ("false", false_lit->ToString());
- auto u32_lit = Literal::CreateR0<uint32>(42);
+ auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
ASSERT_EQ("42", u32_lit->ToString());
- auto s32_lit = Literal::CreateR0<int32>(-999);
+ auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
ASSERT_EQ("-999", s32_lit->ToString());
- auto f32_lit = Literal::CreateR0<float>(3.14f);
+ auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
ASSERT_EQ("3.14", f32_lit->ToString());
- auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
+ auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
ASSERT_EQ("0.5", f16_lit->ToString());
- auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
+ auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
- auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
+ auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
ASSERT_EQ("0.5", bf16_lit->ToString());
// 3.14 will be truncated to 3.125 in bfloat16 format.
auto bf16_lit_truncated =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
+ LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.125", bf16_lit_truncated->ToString());
auto bf16_lit_truncated2 =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
+ LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
ASSERT_EQ("9", bf16_lit_truncated2->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
- auto pred_vec = Literal::CreateR1<bool>({true, false, true});
+ auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
ASSERT_EQ("{101}", pred_vec->ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}});
+ const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
const string expected = R"(s32[3,2] {
{ 1, 2 },
{ 3, 4 },
@@ -144,7 +145,8 @@ TEST_F(LiteralUtilTest, R2ToString) {
}
TEST_F(LiteralUtilTest, R3ToString) {
- const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
+ const auto literal =
+ LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
const string expected = R"(s32[3,2,1] {
{ { 1 },
{ 2 } },
@@ -157,9 +159,9 @@ TEST_F(LiteralUtilTest, R3ToString) {
}
TEST_F(LiteralUtilTest, TupleToString) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -182,7 +184,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
});
// clang-format on
- auto literal = Literal::CreateR3FromArray3D(array_3d);
+ auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
string result = literal->ToString();
const string expected = R"(f32[2,3,2] {
@@ -205,7 +207,7 @@ TEST_F(LiteralUtilTest, CreateSparse) {
{3, 5, 6},
};
std::vector<int64> values = {7, 8, 9, 10};
- auto literal = Literal::CreateSparse<int64>(
+ auto literal = LiteralUtil::CreateSparse<int64>(
dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
Array2D<int64> expected_indices = {
@@ -224,7 +226,7 @@ TEST_F(LiteralUtilTest, CreateSparse) {
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
// clang-format off
- auto literal = Literal::CreateR4Projected<float>({
+ auto literal = LiteralUtil::CreateR4Projected<float>({
{1, 2},
{1001, 1002},
{2001, 2002},
@@ -284,7 +286,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
TEST_F(LiteralUtilTest, EachCellR2F32) {
// clang-format off
- auto literal = Literal::CreateR2<float>({
+ auto literal = LiteralUtil::CreateR2<float>({
{3.1f, 4.2f},
{9.3f, 12.4f},
});
@@ -303,26 +305,27 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
TEST_F(LiteralUtilTest, ScalarEquality) {
// Test equality with scalars.
- auto f32_42 = Literal::CreateR0<float>(42.0);
- auto f32_42_clone = Literal::CreateR0<float>(42.0);
+ auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
EXPECT_EQ(*f32_42, *f32_42);
EXPECT_EQ(*f32_42, *f32_42_clone);
- auto f32_123 = Literal::CreateR0<float>(123.0);
+ auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
EXPECT_NE(*f32_42, *f32_123);
- auto f64_42 = Literal::CreateR0<double>(42.0);
+ auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
EXPECT_NE(*f32_42, *f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
// Test equality with nonscalars.
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
- auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
- auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_different =
+ LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
+ auto vector_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
EXPECT_EQ(*matrix, *matrix);
@@ -335,19 +338,19 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
}
TEST_F(LiteralUtilTest, TokenEquality) {
- auto token0 = Literal::CreateToken();
- auto token1 = Literal::CreateToken();
- auto scalar = Literal::CreateR0<float>(1.0);
+ auto token0 = LiteralUtil::CreateToken();
+ auto token1 = LiteralUtil::CreateToken();
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
EXPECT_EQ(*token0, *token1);
EXPECT_NE(*token0, *scalar);
- EXPECT_EQ(*Literal::MakeTuple({token0.get()}),
- *Literal::MakeTuple({token0.get()}));
- EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}),
- *Literal::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}),
- *Literal::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
+ *LiteralUtil::MakeTuple({token0.get()}));
+ EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
+ *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
+ EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
+ *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
@@ -371,43 +374,46 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
- auto scalar_clone = Literal::CreateR0<float>(1.0);
- auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()});
+ auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
+ auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
EXPECT_EQ(*tuple1, *tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
+ auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
EXPECT_NE(*tuple1, *reversed_tuple);
// Tuple with different value.
- auto scalar_42 = Literal::CreateR0<float>(42.0);
- auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()});
+ auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto different_tuple =
+ LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
EXPECT_NE(*tuple1, *different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
// Test equality with complex values.
- auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
// Vector with the same elements, created as a clone of the original.
- auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector_clone =
+ LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
EXPECT_EQ(*vector, *vector_clone);
- auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
+ auto vector_reversed =
+ LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
EXPECT_NE(*vector, *vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
- auto element1 = Literal::CreateR0<float>(0.0);
- auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = Literal::MakeTuple({element1.get(), element1.get()});
+ auto element1 = LiteralUtil::CreateR0<float>(0.0);
+ auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
+ auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
// Tuples should always return false for IsAll.
EXPECT_FALSE(tuple->IsAll(0));
@@ -416,140 +422,141 @@ TEST_F(LiteralUtilTest, IsAllTuple) {
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
- auto scalar = Literal::CreateR0<float>(0.0);
- auto matrix = Literal::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(0.0);
+ auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
auto x = Literal::CreateFromShape(tuple->shape());
EXPECT_EQ(*tuple, *x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(Literal::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(Literal::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(Literal::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
- EXPECT_TRUE(Literal::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(Literal::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
- EXPECT_TRUE(Literal::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(Literal::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
- EXPECT_TRUE(Literal::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(Literal::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
- EXPECT_FALSE(Literal::CreateR2<uint64>(
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
->IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(Literal::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(Literal::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
EXPECT_FALSE(
- Literal::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
- EXPECT_TRUE(
- Literal::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5));
-
- EXPECT_TRUE(Literal::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(Literal::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
+ ->IsAllFloat(.5));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
EXPECT_FALSE(
- Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
->IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllFirst returns true iff all elements are equal to the first element.
- EXPECT_FALSE(Literal::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
+ EXPECT_FALSE(
+ LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
- auto scalar_zero = Literal::CreateR0<float>(0.0f);
- auto scalar_one = Literal::CreateR0<float>(1.0f);
+ auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
+ auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
EXPECT_TRUE(scalar_zero->IsZero({}));
EXPECT_FALSE(scalar_one->IsZero({}));
- auto array = Literal::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
+ auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
EXPECT_FALSE(array->IsZero({0, 1}));
EXPECT_TRUE(array->IsZero({0, 2}));
EXPECT_TRUE(array->IsZero({1, 1}));
EXPECT_FALSE(array->IsZero({1, 2}));
- auto complex_zero = Literal::CreateR0<complex64>(0.0f);
- auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
+ auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
+ auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
EXPECT_TRUE(complex_zero->IsZero({}));
EXPECT_FALSE(complex_nonzero->IsZero({}));
}
@@ -563,7 +570,7 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
// Make a non-integer for floating point types.
TypeParam half = TypeParam(1) / TypeParam(2);
- auto data = Literal::CreateR2<TypeParam>({{half, 2}, {3, 4}});
+ auto data = LiteralUtil::CreateR2<TypeParam>({{half, 2}, {3, 4}});
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
@@ -577,7 +584,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
}
TEST_F(LiteralUtilTest, ReshapeR0) {
- auto original = Literal::CreateR0<float>(1.7f);
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
EXPECT_EQ(*original, *reshape);
}
@@ -585,13 +592,13 @@ TEST_F(LiteralUtilTest, ReshapeR0) {
TEST_F(LiteralUtilTest, ReshapeR4) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4WithLayout<float>({{
+ auto original = LiteralUtil::CreateR4WithLayout<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// F32[1x3x4x2]
- auto expected = Literal::CreateR3WithLayout<float>({
+ auto expected = LiteralUtil::CreateR3WithLayout<float>({
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
@@ -605,13 +612,13 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4WithLayout<float>({{
+ auto original = LiteralUtil::CreateR4WithLayout<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0minor_);
// F32[1x3x4x2]
- auto expected = Literal::CreateR3WithLayout<float>({
+ auto expected = LiteralUtil::CreateR3WithLayout<float>({
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
@@ -623,7 +630,7 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
}
TEST_F(LiteralUtilTest, TransposeR0) {
- auto original = Literal::CreateR0<float>(1.7f);
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
auto reshape = original->Transpose(/*permutation=*/{});
EXPECT_EQ(*original, *reshape);
}
@@ -631,7 +638,7 @@ TEST_F(LiteralUtilTest, TransposeR0) {
TEST_F(LiteralUtilTest, TransposeR4) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4<float>({{
+ auto original = LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -659,7 +666,7 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
- auto mat_dim0minor = Literal::CreateR2WithLayout<int32>(
+ auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
EXPECT_EQ(mat_dim0minor->element_count(), 6);
EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
@@ -670,7 +677,7 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) {
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
- auto mat_dim0major = Literal::CreateR2WithLayout<int32>(
+ auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
EXPECT_EQ(mat_dim0major->element_count(), 6);
EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
@@ -695,8 +702,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
{10, 11, 12},
},
}); // clang-format on
- auto lit_dim0minor =
- Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_);
+ auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0minor_);
EXPECT_EQ(lit_dim0minor->element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
@@ -710,8 +717,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
- auto lit_dim0major =
- Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_);
+ auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0major_);
EXPECT_EQ(lit_dim0major->element_count(), 12);
EXPECT_THAT(lit_dim0major->data<int32>(),
testing::ElementsAreArray(expected_dim0major));
@@ -723,28 +730,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
}
TEST_F(LiteralUtilTest, SliceR0S32) {
- auto input = Literal::CreateR0<int32>(1);
+ auto input = LiteralUtil::CreateR0<int32>(1);
auto result = input->Slice({}, {});
EXPECT_EQ(*input, *result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
- auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
+ auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
auto result = input->Slice({3}, {4});
- auto expected = Literal::CreateR1<float>({4.0});
+ auto expected = LiteralUtil::CreateR1<float>({4.0});
EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
- auto input_3x4 =
- Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
+ auto input_3x4 = LiteralUtil::CreateR2<uint32>(
+ {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto result = input_3x4->Slice({0, 2}, {2, 4});
- auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}});
+ auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
- auto input_2x3x2 = Literal::CreateR3<uint32>(
+ auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
EXPECT_EQ(*input_2x3x2, *result);
@@ -753,21 +760,21 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
- auto expected = Literal::CreateR1<int64>({77});
+ auto expected = LiteralUtil::CreateR1<int64>({77});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
- auto expected = Literal::CreateR1<uint64>({{77, 88}});
+ auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
- auto expected = Literal::CreateR1<complex64>({{77, 88}});
+ auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
@@ -775,7 +782,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
- Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
+ LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
EXPECT_EQ(output, *expected);
}
@@ -783,7 +790,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {}));
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR0<bfloat16>(h);
+ auto expected = LiteralUtil::CreateR0<bfloat16>(h);
EXPECT_EQ(output, *expected);
}
@@ -791,7 +798,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {3}));
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR1<bfloat16>({h, h, h});
+ auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
EXPECT_EQ(output, *expected);
}
@@ -799,28 +806,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
+ auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
- auto expected = Literal::CreateR0<float>(2.5f);
+ auto expected = LiteralUtil::CreateR0<float>(2.5f);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
- auto expected = Literal::CreateR1<int64>({-7, -7, -7});
+ auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
- auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
+ auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
EXPECT_EQ(output, *expected);
}
@@ -828,7 +835,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateWithValue<complex64>({4, 2});
auto expected =
- Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
+ LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
EXPECT_EQ(output, *expected);
}
@@ -836,7 +843,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
Literal output(ShapeUtil::MakeShape(F16, {}));
half h(0.25f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR0<half>(h);
+ auto expected = LiteralUtil::CreateR0<half>(h);
EXPECT_EQ(output, *expected);
}
@@ -844,7 +851,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
Literal output(ShapeUtil::MakeShape(F16, {3}));
half h(0.5f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR1<half>({h, h, h});
+ auto expected = LiteralUtil::CreateR1<half>({h, h, h});
EXPECT_EQ(output, *expected);
}
@@ -852,15 +859,15 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
half h(2.0f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
+ auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
- auto input =
- Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
+ auto input = LiteralUtil::CreateR2<uint32>(
+ {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto output = input->Replicate<uint32>(3);
- auto expected = Literal::CreateR3<uint32>(
+ auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
@@ -914,12 +921,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
}
TEST_F(LiteralUtilTest, CopyFromScalars) {
- auto zero = Literal::CreateR0<uint32>(0);
- auto nine = Literal::CreateR0<uint32>(9);
+ auto zero = LiteralUtil::CreateR0<uint32>(0);
+ auto nine = LiteralUtil::CreateR0<uint32>(9);
TF_EXPECT_OK(zero->CopyFrom(*nine));
EXPECT_EQ(*zero, *nine);
- auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
+ auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
EXPECT_EQ(zero->Get<uint32>({}), 17);
TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
@@ -928,13 +935,13 @@ TEST_F(LiteralUtilTest, CopyFromScalars) {
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
- const auto const_nine = Literal::CreateR1<float>({9});
+ const auto const_nine = LiteralUtil::CreateR1<float>({9});
const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
{
// Source contains dimension with zero elements.
const auto empty = Literal::CreateFromShape(empty_r1_shape);
- auto nine = Literal::CreateR1<float>({9});
+ auto nine = LiteralUtil::CreateR1<float>({9});
TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
EXPECT_EQ(*nine, *const_nine);
@@ -943,7 +950,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
{
// Copy 0 element to destination with zero elements.
const auto empty = Literal::CreateFromShape(empty_r1_shape);
- auto nine = Literal::CreateR1<float>({9});
+ auto nine = LiteralUtil::CreateR1<float>({9});
TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
EXPECT_EQ(*empty, *const_empty);
@@ -958,16 +965,16 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
}
TEST_F(LiteralUtilTest, CopyFromArrays) {
- auto scalar_42 = Literal::CreateR0<float>(42.0);
- auto scalar_123 = Literal::CreateR0<float>(123.0);
+ auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
EXPECT_NE(*scalar_42, *scalar_123);
TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
/*src_shape_index=*/{}));
EXPECT_EQ(*scalar_42, *scalar_123);
EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
- auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+ auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
EXPECT_NE(*matrix_1234, *matrix_5678);
EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
@@ -977,19 +984,19 @@ TEST_F(LiteralUtilTest, CopyFromArrays) {
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
+ auto nested_tuple = LiteralUtil::MakeTuple(
{matrix.get(),
- Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get(),
- &nil_literal})
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
.get()});
// Create a tuple the same shape as the inner tuple of nested_tuple but with
// different values.
- auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(),
- Literal::CreateR1<double>({2.0, 4.0}).get(),
- &nil_literal});
+ auto tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(-5).get(),
+ LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
@@ -1010,8 +1017,8 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = Literal::MakeTuple(
- {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()});
+ auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
+ LiteralUtil::CreateR0<int32>(4).get()});
EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
@@ -1025,8 +1032,8 @@ TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto vector = Literal::CreateR1<float>({5.0, 7.0});
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
Status status = matrix->CopyFrom(*vector);
ASSERT_FALSE(status.ok());
ASSERT_THAT(status.error_message(),
@@ -1051,7 +1058,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
- auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
+ auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l2 = m2.get();
const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
EXPECT_EQ(d2[0], 0);
@@ -1150,12 +1157,12 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
TEST_F(LiteralUtilTest, ConvertR4) {
// clang-format off
- auto original = Literal::CreateR4WithLayout<int8>({{
+ auto original = LiteralUtil::CreateR4WithLayout<int8>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
- auto expected = Literal::CreateR4WithLayout<uint32>({{
+ auto expected = LiteralUtil::CreateR4WithLayout<uint32>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -1169,42 +1176,42 @@ TEST_F(LiteralUtilTest, ConvertR4) {
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
// clang-format off
- auto s8 = Literal::CreateR4WithLayout<int8>({{
+ auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto s32 = Literal::CreateR4WithLayout<int32>({{
+ auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto u32 = Literal::CreateR4WithLayout<uint32>({{
+ auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto s64 = Literal::CreateR4WithLayout<int64>({{
+ auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto u64 = Literal::CreateR4WithLayout<uint64>({{
+ auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto pred = Literal::CreateR4WithLayout<bool>({{
+ auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
{{true, false, true, false}, {false, true, false, true}},
{{false, true, false, true}, {true, false, true, false}},
{{true, false, true, false}, {false, true, false, true}},
}}, layout_r4_dim0major_);
- auto int32_pred = Literal::CreateR4WithLayout<int32>({{
+ auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
{{1, 0, 1, 0}, {0, 1, 0, 1}},
{{0, 1, 0, 1}, {1, 0, 1, 0}},
{{1, 0, 1, 0}, {0, 1, 0, 1}},
}}, layout_r4_dim0major_);
- auto f16 = Literal::CreateR4WithLayout<half>({{
+ auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
{{half(10.0), half(0.0), half(12.0), half(0.0)},
{half(0.0), half(15.0), half(0.0), half(17.0)}},
{{half(0.0), half(19.0), half(0.0), half(21.0)},
@@ -1212,7 +1219,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_);
- auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
+ auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
{{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
{{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
@@ -1220,17 +1227,17 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
}}, layout_r4_dim0major_);
- auto f32 = Literal::CreateR4WithLayout<float>({{
+ auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
- auto f64 = Literal::CreateR4WithLayout<double>({{
+ auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
}}, layout_r4_dim0major_);
- auto c64 = Literal::CreateR4WithLayout<complex64>({{
+ auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
@@ -1302,18 +1309,18 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
}
TEST_F(LiteralUtilTest, BitcastConvert) {
- auto original =
- Literal::CreateR1<uint32>({tensorflow::bit_cast<uint32>(2.5f),
- tensorflow::bit_cast<uint32>(-42.25f),
- tensorflow::bit_cast<uint32>(100.f), 0xbeef});
- auto expected = Literal::CreateR1<float>(
+ auto original = LiteralUtil::CreateR1<uint32>(
+ {tensorflow::bit_cast<uint32>(2.5f),
+ tensorflow::bit_cast<uint32>(-42.25f),
+ tensorflow::bit_cast<uint32>(100.f), 0xbeef});
+ auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
original->BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
- auto literal = Literal::CreateR0<uint32>(1234);
+ auto literal = LiteralUtil::CreateR0<uint32>(1234);
Status status = literal->BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
@@ -1348,7 +1355,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h1(1.0f);
half h2(2.0f);
- auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
+ auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l = m.get();
EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
EXPECT_EQ(4, l->data<half>().size());
@@ -1391,10 +1398,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
}
TEST_F(LiteralUtilTest, LiteralSliceTest) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
Literal nil(ShapeUtil::MakeNil());
EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
@@ -1413,10 +1420,10 @@ TEST_F(LiteralUtilTest, LiteralSliceTest) {
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
const auto nested_tuple_view = LiteralSlice(*nested_tuple);
@@ -1436,15 +1443,16 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
}
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
const auto nested_tuple_view = LiteralSlice(*nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
- EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ EXPECT_EQ(matrix_view,
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1488,7 +1496,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
TEST_F(LiteralUtilTest, LiteralMove) {
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal(std::move(*matrix));
EXPECT_TRUE(
@@ -1501,11 +1509,11 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
- {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get(),
- &nil_literal})
+ auto nested_tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
.get(),
&nil_literal});
@@ -1542,13 +1550,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*Literal::CreateR0<float>(1.0)));
- elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(
- *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get()})
+ elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
+ elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
+ elements.push_back(std::move(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
- ));
+ ));
Literal literal = Literal::MoveIntoTuple(&elements);
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1577,7 +1585,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
literal = std::move(*matrix);
EXPECT_TRUE(
@@ -1590,7 +1598,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
const auto matrix_view = LiteralSlice(*matrix);
LiteralSlice matrix_view_copy(matrix_view);
@@ -1601,9 +1609,9 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = Literal::MakeTuple(
- {Literal::CreateR0<float>(42.0).get(),
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
+ auto tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(42.0).get(),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
@@ -1644,20 +1652,20 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
// Test serializing then deserializing a Literal through a proto.
- auto one_f32 = Literal::CreateR0<float>(1.0);
- auto two_f32 = Literal::CreateR0<float>(2.0);
- auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
- auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- auto vector_bfloat16 = Literal::CreateR1<bfloat16>(
+ auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
+ auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
+ auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
auto vector_half =
- Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
+ LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
auto matrix_pred =
- Literal::CreateR2<bool>({{true, false, true}, {false, false, true}});
- auto tuple = Literal::MakeTuple(
+ LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
+ auto tuple = LiteralUtil::MakeTuple(
{one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
+ auto nested_tuple = LiteralUtil::MakeTuple(
{tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
@@ -1790,8 +1798,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
}
TEST_F(LiteralUtilTest, SortSparseElements) {
- auto literal =
- Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {});
+ auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
+ SparseIndexArray(10, 3), {});
literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
@@ -1805,21 +1813,22 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
ASSERT_EQ(
- Literal::CreateSparse<bool>(dimensions, indices, {true, false, true})
+ LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
->GetSparseElementAsString(1),
"false");
- ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
+ ASSERT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
->GetSparseElementAsString(1),
tensorflow::strings::StrCat(int64{2}));
- ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(double{2.0}));
- ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices,
- {half{1.0}, half{2.0}, half{3.0}})
+ ASSERT_EQ(
+ LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat(double{2.0}));
+ ASSERT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
+ {half{1.0}, half{2.0}, half{3.0}})
->GetSparseElementAsString(1),
tensorflow::strings::StrCat(static_cast<float>(half{2.0})));
ASSERT_EQ(
- Literal::CreateSparse<complex64>(
+ LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
->GetSparseElementAsString(1),
@@ -1827,33 +1836,36 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
/*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 1}, {2, 2}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
/*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 2}, {1, 2}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(9);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
/*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int32>({{9, 9}, {9, 9}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index eeabf835ac..548fbe8a83 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -43,25 +43,6 @@ namespace xla {
namespace {
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
-// Converts between little and big endian.
-//
-// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
-void ConvertEndianShort(string* bytes) {
- CHECK_EQ(bytes->size() % 2, 0);
- for (int64 i = 0; i < bytes->size(); i += 2) {
- std::swap((*bytes)[i], (*bytes)[i + 1]);
- }
-}
-
-void ConvertEndianShort(char* bytes, int64 size) {
- CHECK_EQ(size % 2, 0);
- for (int64 i = 0; i < size; i += 2) {
- std::swap(bytes[i], bytes[i + 1]);
- }
-}
-
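The two deleted helpers reduce to a pairwise byte swap over 16-bit elements. A standalone restatement that actually enforces the even-size precondition stated in the comment:

#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Flip the endianness of every 16-bit element in `bytes`.
// Precondition: bytes->size() is even (each element is two bytes wide).
void ConvertEndianShort(std::string* bytes) {
  assert(bytes->size() % 2 == 0);
  for (std::size_t i = 0; i < bytes->size(); i += 2) {
    std::swap((*bytes)[i], (*bytes)[i + 1]);
  }
}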
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
@@ -103,505 +84,54 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-LiteralBase::~LiteralBase() {}
-
-std::ostream& operator<<(std::ostream& out, const Literal& literal) {
- out << literal.ToString();
- return out;
-}
-
-Literal::StrideConfig::StrideConfig(
- const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions)
- : dimensions(dimensions),
- base(dimensions.size(), 0),
- step(dimensions.size(), 1) {
- if (!dimensions.empty()) {
- // Selects the shape with the largest minor dimension as the one upon
- // which to run the tight stride loop.
- if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
- dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
- minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
- dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
- } else {
- minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
- source_stride =
- IndexUtil::GetDimensionStride(source_shape, minor_dimension);
- }
- minor_loop_size = dimensions[minor_dimension];
- step[minor_dimension] = minor_loop_size;
- }
-}
-
-Literal::Literal(const Shape& shape)
- : Literal(shape, /*allocate_arrays=*/true) {}
-
-void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
- if (ShapeUtil::IsTuple(shape)) {
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& subshape = shape.tuple_shapes(i);
-
- auto child_piece = Piece();
- child_piece.set_subshape(&subshape);
-
- SetPiece(subshape, &child_piece, allocate_arrays);
-
- piece->emplace_back(std::move(child_piece));
- }
- } else if (ShapeUtil::IsArray(shape)) {
- if (allocate_arrays) {
- if (LayoutUtil::IsSparseArray(shape)) {
- // For sparse arrays, the buffer must be of the size of the maximum
- // number of sparse elements possible.
- const int64 max_sparse_elements =
- LayoutUtil::MaxSparseElements(shape.layout());
- piece->set_buffer(
- new char[max_sparse_elements *
- ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
- piece->set_sparse_indices(
- new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
- } else {
- piece->set_buffer(new char[piece->size_bytes()]);
- }
- }
- } else {
- // If the shape is neither an array nor tuple, then it must be
- // zero-sized. Otherwise, some memory needs to be allocated for it.
- CHECK_EQ(piece->size_bytes(), 0);
- }
-}
-
-Literal::Literal(const Shape& shape, bool allocate_arrays)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
- CHECK(LayoutUtil::HasLayout(*shape_));
- root_piece_ = new Piece();
- root_piece_->set_subshape(shape_.get());
- CHECK(&root_piece_->subshape() == shape_.get());
-
- SetPiece(*shape_, root_piece_, allocate_arrays);
-}
-
-Literal::~Literal() {
- if (root_piece_ != nullptr) {
- DeallocateBuffers();
- delete root_piece_;
- }
-}
-
-void Literal::DeallocateBuffers() {
- root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* piece) {
- if (piece->buffer() != nullptr) {
- delete[] piece->buffer();
- delete piece->sparse_indices();
- }
- });
-}
-
-Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
-
-Literal& Literal::operator=(Literal&& other) {
- DCHECK(&other.root_piece_->subshape() == other.shape_.get());
- using std::swap;
- swap(shape_, other.shape_);
- swap(root_piece_, other.root_piece_);
- DCHECK(&root_piece_->subshape() == shape_.get());
-
- return *this;
-}
-
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = MakeUnique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* piece) {
- if (ShapeUtil::IsArray(piece->subshape())) {
- memset(piece->untyped_data(), 0, piece->size_bytes());
- }
- });
- return literal;
-}
-
-const SparseIndexArray* LiteralBase::sparse_indices(
- const ShapeIndex& shape_index) const {
- return piece(shape_index).sparse_indices();
-}
-
-SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
- return piece(shape_index).sparse_indices();
-}
-
-/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
+ return Literal::CreateFromShape(
+ ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
-template <typename NativeT>
-Status Literal::CopySliceFromInternal(
- const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
- TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
- TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
-
- auto linear_index = [](const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
- return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
- };
-
- if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
- ShapeUtil::Rank(shape()) == 0) {
- // If any of the two shapes are scalars, we can just call the StridedCopy()
- // directly, and we know we will be copying only one value.
- TF_RET_CHECK(copy_size.empty());
- StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
- src_literal.data<NativeT>(),
- linear_index(src_literal.shape(), src_base), 0, 1);
- } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
- !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
- // Perform copy if neither src nor dest has dimensions with zero element,
- // otherwise it's a no-op.
- TF_RET_CHECK(src_base.size() == dest_base.size());
- TF_RET_CHECK(src_base.size() == copy_size.size());
-
- // Scan the source from minor, stepping in copy size blocks, then within
- // the index enumeration functor, do a strided copy advancing source index
- // by one (walking through the minor dimension), and destination index by
- // proper stride size at the matching dimension.
- DimensionVector src_indexes(src_base.size(), 0);
- DimensionVector dest_indexes(dest_base.size(), 0);
- Literal::StrideConfig stride_config(src_literal.shape(), shape(),
- copy_size);
-
- auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- // Map from multi-dimensional index, to source index.
- std::transform(indexes.begin(), indexes.end(), src_base.begin(),
- src_indexes.begin(), std::plus<int64>());
- // Map from multi-dimensional index, to destination index.
- std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
- dest_indexes.begin(), std::plus<int64>());
-
- int64 src_index = linear_index(src_literal.shape(), src_indexes);
- int64 dest_index = linear_index(shape(), dest_indexes);
-
- // `this->` is needed to workaround MSVC bug: #16882
- StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
- src_literal.data<NativeT>(), src_index,
- stride_config.source_stride, stride_config.minor_loop_size);
- return true;
- };
-
- ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
- stride_config.dimensions, stride_config.step,
- copy_proc);
- }
- return Status::OK();
-}
-
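The per-minor-dimension step named above, StridedCopy, is a plain strided element copy. A self-contained sketch of the semantics this code relies on; the real XLA helper may differ in details:

#include <cstdint>
#include <vector>

// Copy `count` elements from `src` into `dest`, each side advancing by its
// own stride; this is the inner step of CopySliceFromInternal.
template <typename T>
void StridedCopy(std::vector<T>& dest, int64_t dest_index, int64_t dest_stride,
                 const std::vector<T>& src, int64_t src_index,
                 int64_t src_stride, int64_t count) {
  for (int64_t i = 0; i < count; ++i) {
    dest[dest_index + i * dest_stride] = src[src_index + i * src_stride];
  }
}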
-Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index) {
- DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
- const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- src_literal.shape(), src_index);
- const int64 dest_linear_index =
- IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
- const int64 primitive_size =
- ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
-
- char* dest_address =
- static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
- const char* source_address =
- static_cast<const char*>(src_literal.untyped_data()) +
- src_linear_index * primitive_size;
- if (dest_address != source_address) {
- memcpy(dest_address, source_address, primitive_size);
- }
- return Status::OK();
-}
-
-/* static */ std::unique_ptr<Literal> Literal::CreateToken() {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
}
-std::vector<Literal> Literal::DecomposeTuple() {
- CHECK(ShapeUtil::IsTuple(shape()));
- std::vector<Literal> elements;
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
- /*allocate_arrays=*/false));
- Literal& element = elements.back();
- element.root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* dest_piece) {
- ShapeIndex src_index = {i};
- for (int64 j : index) {
- src_index.push_back(j);
- }
- Piece& src_piece = piece(src_index);
-
- // Move the respective buffer and sparse indices over to the element
- // Literal.
- dest_piece->set_buffer(src_piece.buffer());
- src_piece.set_buffer(nullptr);
- dest_piece->set_sparse_indices(src_piece.sparse_indices());
- src_piece.set_sparse_indices(nullptr);
- });
- }
- // Set this literal to be nil-shaped.
- *this = Literal();
- return elements;
-}
-
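A usage sketch grounded only in the signatures visible in this diff: DecomposeTuple hands the element buffers over instead of copying them, and leaves the source literal nil-shaped.

// Sketch only; relies on the xla::Literal API as shown in this change.
auto tuple = xla::LiteralUtil::MakeTuple(
    {xla::LiteralUtil::CreateR0<xla::int32>(42).get(),
     xla::LiteralUtil::CreateR1<double>({23.0, 44.0}).get()});
std::vector<xla::Literal> elements = tuple->DecomposeTuple();
// elements[0] and elements[1] now own the buffers; *tuple is nil-shaped.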
-/* static */ Literal Literal::MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements) {
- std::vector<Shape> element_shapes;
- for (const Literal& element : elements) {
- element_shapes.push_back(element.shape());
- }
- Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
- /*allocate_arrays=*/false);
- for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(
- literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
- }
- return literal;
-}
-
-namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
-template <typename NativeT>
-void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- tensorflow::gtl::ArraySlice<NativeT> src,
- const Shape& dest_shape, const Shape& src_shape) {
- CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
- if (ShapeUtil::IsZeroElementArray(dest_shape)) {
- return;
- }
- std::vector<int64> index(ShapeUtil::Rank(dest_shape));
- do {
- dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
- src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
- } while (IndexUtil::BumpIndices(dest_shape, &index));
-}
-
-} // namespace
-
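CopyElementsBetween depends only on IndexUtil::BumpIndices visiting every multi-index exactly once. One conventional odometer-style variant of that enumeration, sketched standalone (the XLA helper's iteration order may differ):

#include <cstdint>
#include <vector>

// Advance `index` to the next multi-index within bounds `dims`, odometer
// style; returns false once every index has been visited.
bool BumpIndices(const std::vector<int64_t>& dims,
                 std::vector<int64_t>* index) {
  for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
    if (++(*index)[d] < dims[d]) return true;
    (*index)[d] = 0;
  }
  return false;
}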
-Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
- CHECK(subshape_ != nullptr);
- CHECK(src.subshape_ != nullptr);
- if (ShapeUtil::Equal(subshape(), src.subshape())) {
- // If the layouts are equal it's faster just to memcpy.
- memcpy(buffer(), src.buffer(), src.size_bytes());
- } else {
- TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
- std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
- switch (subshape().element_type()) {
-#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
- case (XLA_T): \
- CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
- subshape(), src.subshape()); \
- break;
- COPY_ELEMENTS(U8, uint8);
- COPY_ELEMENTS(U16, uint16);
- COPY_ELEMENTS(U32, uint32);
- COPY_ELEMENTS(U64, uint64);
- COPY_ELEMENTS(S8, int8);
- COPY_ELEMENTS(S16, int16);
- COPY_ELEMENTS(S32, int32);
- COPY_ELEMENTS(S64, int64);
- COPY_ELEMENTS(F16, half);
- COPY_ELEMENTS(BF16, bfloat16);
- COPY_ELEMENTS(F32, float);
- COPY_ELEMENTS(F64, double);
- COPY_ELEMENTS(C64, complex64);
- COPY_ELEMENTS(PRED, bool);
-#undef COPY_ELEMENTS
- default:
- return Unimplemented(
- "Copying a Literal object with element type %s is not implemented.",
- PrimitiveType_Name(subshape().element_type()).c_str());
- }
- }
- return Status::OK();
-}
-
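The memcpy fast path above is valid only because equal subshapes imply identical layouts; otherwise every element must be re-addressed. The distinction in miniature, assuming a row-major source and a column-major destination:

#include <vector>

// Copy a rows x cols array from row-major `src` into column-major `dest`.
// With matching layouts one memcpy would do; with differing layouts each
// element is re-addressed, as in Piece::CopyFrom's fallback path.
void CopyRowMajorToColMajor(const std::vector<float>& src,
                            std::vector<float>* dest, int rows, int cols) {
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      (*dest)[c * rows + r] = src[r * cols + c];
}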
-Status Literal::CopyFrom(const LiteralSlice& src_literal,
- const ShapeIndex& dest_shape_index,
- const ShapeIndex& src_shape_index) {
- const Shape& dest_subshape =
- ShapeUtil::GetSubshape(shape(), dest_shape_index);
- const Shape& src_subshape =
- ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
- if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
- return InvalidArgument(
- "Destination subshape incompatible with source subshape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_subshape).c_str());
- }
- return root_piece_->ForEachMutableSubpieceWithStatus(
- [&](const ShapeIndex& index, Piece* piece) {
- if (!ShapeUtil::IsArray(piece->subshape())) {
- return Status::OK();
- }
-
- // Determine if this index is in the part of this literal that we want
- // to copy over from src_literal.
- bool in_subtree_to_copy = true;
- for (int i = 0; i < dest_shape_index.size(); ++i) {
- if (index[i] != dest_shape_index[i]) {
- in_subtree_to_copy = false;
- break;
- }
- }
- if (!in_subtree_to_copy) {
- return Status::OK();
- }
- // Construct the index of the corresponding piece in the source literal.
- ShapeIndex src_piece_index = src_shape_index;
- for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
- src_piece_index.push_back(index[i]);
- }
- TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
- return Status::OK();
- });
-}
-
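A usage sketch of the shape-index parameters, based on the signature above and on the CopyFromArrays test earlier in this diff: dest_shape_index selects which subtree of a tuple receives the copy.

// Sketch only; mirrors the calls exercised by the tests in literal_test.cc.
auto matrix = xla::LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto scalar = xla::LiteralUtil::CreateR0<float>(9.0);
auto tuple = xla::LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
// Overwrite tuple element {1} (the matrix) from the standalone literal.
TF_CHECK_OK(tuple->CopyFrom(*matrix, /*dest_shape_index=*/{1},
                            /*src_shape_index=*/{}));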
-Status Literal::MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index) {
- const Shape& dest_subshape =
- ShapeUtil::GetSubshape(shape(), dest_shape_index);
- if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
- return InvalidArgument(
- "Destination subshape not equal to source shape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_literal.shape()).c_str());
- }
-
- src_literal.root_piece_->ForEachSubpiece(
- [&](const ShapeIndex& src_index, const Piece& src_piece) {
- if (!ShapeUtil::IsArray(src_piece.subshape())) {
- return;
- }
-
- ShapeIndex dest_index = dest_shape_index;
- for (int64 i : src_index) {
- dest_index.push_back(i);
- }
- Piece& dest_piece = piece(dest_index);
- delete[] dest_piece.buffer();
- dest_piece.set_buffer(src_piece.buffer());
- delete dest_piece.sparse_indices();
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- });
-
- src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
- delete src_literal.root_piece_;
- src_literal.root_piece_ = new LiteralBase::Piece();
- src_literal.root_piece_->set_subshape(src_literal.shape_.get());
-
- return Status::OK();
-}
-
-Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
- TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
- TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
- << ShapeUtil::HumanString(src_literal.shape());
- TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
-
- switch (shape().element_type()) {
- case U8:
- return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
- copy_size);
- case U16:
- return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
- copy_size);
- case U32:
- return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
- copy_size);
- case U64:
- return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
- copy_size);
- case S8:
- return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
- copy_size);
- case S16:
- return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
- copy_size);
- case S32:
- return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
- copy_size);
- case S64:
- return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
- copy_size);
- case F16:
- return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
- copy_size);
- case BF16:
- return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
- copy_size);
- case F32:
- return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
- copy_size);
- case F64:
- return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
- copy_size);
- case C64:
- return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
- copy_size);
- case PRED:
- return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
- copy_size);
- default:
- break;
- }
- return Unimplemented(
- "Copying a slice from a Literal object with element type %d is not "
- "implemented.",
- shape().element_type());
-}
-
-/* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*Literal::CreateR0<uint8>(0));
+ return std::move(*LiteralUtil::CreateR0<uint8>(0));
case U32:
- return std::move(*Literal::CreateR0<uint32>(0));
+ return std::move(*LiteralUtil::CreateR0<uint32>(0));
case U64:
- return std::move(*Literal::CreateR0<uint64>(0));
+ return std::move(*LiteralUtil::CreateR0<uint64>(0));
case S8:
- return std::move(*Literal::CreateR0<int8>(0));
+ return std::move(*LiteralUtil::CreateR0<int8>(0));
case S32:
- return std::move(*Literal::CreateR0<int32>(0));
+ return std::move(*LiteralUtil::CreateR0<int32>(0));
case S64:
- return std::move(*Literal::CreateR0<int64>(0));
+ return std::move(*LiteralUtil::CreateR0<int64>(0));
case F16:
- return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
+ return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
case BF16:
return std::move(
- *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
case F32:
- return std::move(*Literal::CreateR0<float>(0));
+ return std::move(*LiteralUtil::CreateR0<float>(0));
case F64:
- return std::move(*Literal::CreateR0<double>(0));
+ return std::move(*LiteralUtil::CreateR0<double>(0));
case C64:
- return std::move(*Literal::CreateR0<complex64>(0));
+ return std::move(*LiteralUtil::CreateR0<complex64>(0));
case PRED:
- return std::move(*Literal::CreateR0<bool>(false));
+ return std::move(*LiteralUtil::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -614,33 +144,33 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::One(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*Literal::CreateR0<uint8>(1));
+ return std::move(*LiteralUtil::CreateR0<uint8>(1));
case U32:
- return std::move(*Literal::CreateR0<uint32>(1));
+ return std::move(*LiteralUtil::CreateR0<uint32>(1));
case U64:
- return std::move(*Literal::CreateR0<uint64>(1));
+ return std::move(*LiteralUtil::CreateR0<uint64>(1));
case S8:
- return std::move(*Literal::CreateR0<int8>(1));
+ return std::move(*LiteralUtil::CreateR0<int8>(1));
case S32:
- return std::move(*Literal::CreateR0<int32>(1));
+ return std::move(*LiteralUtil::CreateR0<int32>(1));
case S64:
- return std::move(*Literal::CreateR0<int64>(1));
+ return std::move(*LiteralUtil::CreateR0<int64>(1));
case F16:
- return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
+ return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
case BF16:
return std::move(
- *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
case F32:
- return std::move(*Literal::CreateR0<float>(1));
+ return std::move(*LiteralUtil::CreateR0<float>(1));
case F64:
- return std::move(*Literal::CreateR0<double>(1));
+ return std::move(*LiteralUtil::CreateR0<double>(1));
case C64:
- return std::move(*Literal::CreateR0<complex64>(1));
+ return std::move(*LiteralUtil::CreateR0<complex64>(1));
case PRED:
- return std::move(*Literal::CreateR0<bool>(true));
+ return std::move(*LiteralUtil::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -653,44 +183,44 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
- *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
case U32:
return std::move(
- *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
case U64:
return std::move(
- *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
case S8:
return std::move(
- *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
case S32:
return std::move(
- *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
case S64:
return std::move(
- *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
case F32:
- return std::move(
- *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity()));
case F64:
- return std::move(
- *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity()));
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*Literal::CreateR0<bool>(false));
+ return std::move(*LiteralUtil::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*Literal::CreateR0<half>(
+ return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())));
case BF16:
- return std::move(*Literal::CreateR0<bfloat16>(
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
@@ -701,42 +231,42 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
- *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
case U32:
return std::move(
- *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
case U64:
return std::move(
- *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
case S8:
return std::move(
- *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
case S32:
return std::move(
- *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
case S64:
return std::move(
- *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
case F32:
- return std::move(
- *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity()));
case F64:
- return std::move(
- *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity()));
case PRED:
- return std::move(*Literal::CreateR0<bool>(true));
+ return std::move(*LiteralUtil::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*Literal::CreateR0<half>(
+ return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())));
case BF16:
- return std::move(*Literal::CreateR0<bfloat16>(
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
@@ -747,7 +277,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ std::unique_ptr<Literal> Literal::CreateR1(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
@@ -755,17 +285,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
return literal;
}
-void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(element_count(), values.bits());
- CHECK_EQ(shape().element_type(), PRED);
- for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
- Set({i}, values.get(i));
- }
-}
-
-/* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
tensorflow::StringPiece value) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
@@ -775,116 +295,13 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::CreateR2F32Linspace(float from,
- float to,
- int64 rows,
- int64 cols) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
+ float from, float to, int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
- // Create new shape with 'new_layout' set at the given shape index.
- Shape new_shape = shape();
- Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
- TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
- *subshape->mutable_layout() = new_layout;
- auto result = MakeUnique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
- CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
- << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
- << " not compatible with literal shape "
- << ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
- ShapeUtil::ForEachSubshape(
- result->shape(),
- [this, &result](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
- }
- });
- return result;
-}
-
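Relayout does no element shuffling of its own: it stamps the requested layout onto an empty result and lets CopyFrom reorder the data. A usage sketch, assuming LayoutUtil::MakeLayout takes a minor-to-major sequence and that the shape-index argument defaults as in the header:

// Sketch: produce a column-major copy of a row-major 2x2 literal.
auto m = xla::LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
std::unique_ptr<xla::Literal> col_major =
    m->Relayout(xla::LayoutUtil::MakeLayout({0, 1}));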
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
- if (!ShapeUtil::IsArray(shape())) {
- return InvalidArgument("Broadcast only supports arrays.");
- }
-
- for (int64 i = 0; i < dimensions.size(); i++) {
- TF_RET_CHECK(shape().dimensions(i) ==
- result_shape.dimensions(dimensions[i]));
- }
-
- std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
-
- // scratch_source_index is temporary storage space for the computed index into
- // the input literal. We put it here to avoid allocating an std::vector in
- // every iteration of ShapeUtil::ForEachIndex.
- std::vector<int64> scratch_source_index(shape().dimensions_size());
-
- char* dest_data = static_cast<char*>(result->untyped_data());
- const char* source_data = static_cast<const char*>(untyped_data());
- const int64 primitive_size =
- ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
-
- ShapeUtil::ForEachIndex(
- result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
- for (int64 i = 0; i < dimensions.size(); ++i) {
- scratch_source_index[i] = output_index[dimensions[i]];
- }
- int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- result_shape, output_index);
- int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- shape(), scratch_source_index);
- memcpy(dest_data + primitive_size * dest_index,
- source_data + primitive_size * source_index, primitive_size);
- return true;
- });
-
- return std::move(result);
-}
-
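The essential step of the deleted Broadcast is the projection of each output index through `dimensions` back onto a source index (the scratch_source_index loop). The same projection in plain C++ for the vector-to-matrix cases covered by the BroadcastVectorToMatrix tests earlier in this diff:

#include <cstdint>
#include <vector>

// Broadcast a length-n vector into a row-major n x n matrix along `dim`.
std::vector<int64_t> BroadcastVector(const std::vector<int64_t>& v, int dim) {
  const int64_t n = static_cast<int64_t>(v.size());
  std::vector<int64_t> out(n * n);
  for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < n; ++j)
      // The source index is the output index along the broadcast dimension.
      out[i * n + j] = v[dim == 0 ? i : j];
  return out;
}
// BroadcastVector({1, 2}, 0) == {1, 1, 2, 2}  // {{1,1},{2,2}}, dimensions={0}
// BroadcastVector({1, 2}, 1) == {1, 2, 1, 2}  // {{1,2},{1,2}}, dimensions={1}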
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
- if (!ShapeUtil::IsArray(shape())) {
- return InvalidArgument("Reshape does not support tuples.");
- }
- std::unique_ptr<Literal> output;
- if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
- output =
- Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
- } else {
- output = CloneToUnique();
- }
- // Because the layout is monotonic, we can simply reuse the same sequence of
- // values without changing their order.
- *output->mutable_shape_do_not_use() =
- ShapeUtil::MakeShape(shape().element_type(), dimensions);
-
- int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
- if (elements_before != elements_after) {
- return InvalidArgument(
- "Shapes before and after Literal::Reshape have different numbers "
- "of elements: %s vs %s.",
- ShapeUtil::HumanString(shape()).c_str(),
- ShapeUtil::HumanString(output->shape()).c_str());
- }
- return std::move(output);
-}
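Reshape keeps the element count fixed and, per the code above, relies on a monotonic dim0-major layout; a sketch:

    auto r1 = xla::LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5, 6});
    auto r2 = r1->Reshape({2, 3}).ValueOrDie();
    // Same six values reinterpreted row-major: {{1, 2, 3}, {4, 5, 6}}.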
-
-/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal) {
@@ -956,575 +373,64 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
return new_literal;
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const {
- CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
- CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
- << "Given permutation is not a permutation of dimension numbers";
- // To transpose the array, we just permute the dimensions and layout, and
- // do a straight memory copy of the raw data set.
- // This is considerably faster than iterating over every array element using
- // the EachCell<>() and Set<>() APIs.
- std::vector<int64> inverse_permutation = InversePermutation(permutation);
- Shape permuted_shape =
- ShapeUtil::PermuteDimensions(inverse_permutation, shape());
- // Replace the layout with one affine to this shape, such that a
- // transpose operation can be performed by leaving the flat values
- // representation intact.
- // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
- // The shape with affine layout resulting from that operation will be
- // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
- // one of size 8) as the most-minor dimension.
- //
- // Essentially, given MinMaj(Di) the position of the Di dimension within the
- // minor to major vector, and given T(Di) the index that the original Di
- // dimension has within the transposed array, a layout is affine if
- // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
- // vector of the affine layout.
- CHECK(LayoutUtil::IsDenseArray(permuted_shape));
- Layout* layout = permuted_shape.mutable_layout();
- layout->clear_minor_to_major();
- for (auto index : LayoutUtil::MinorToMajor(shape())) {
- layout->add_minor_to_major(inverse_permutation[index]);
- }
- auto new_literal = MakeUnique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
- ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
- return new_literal;
-}
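Caller-side sketch of the zero-copy transpose described above (illustrative):

    auto m = xla::LiteralUtil::CreateR2<xla::int32>({{1, 2, 3}, {4, 5, 6}});  // [2,3]
    auto t = m->Transpose({1, 0});  // [3,2]; the raw buffer is memcpy'd unchanged
    // t->Get<xla::int32>({2, 0}) == 3.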
-
-template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const {
- auto result_literal = MakeUnique<Literal>(result_shape);
- DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
- for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
- });
- return result_literal;
-}
-
-std::unique_ptr<Literal> LiteralBase::Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const {
- CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
-
- DimensionVector result_dimensions;
- for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
- CHECK_GE(start_indices[dnum], 0);
- CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
- << "dnum = " << dnum;
- int64 dimension = limit_indices[dnum] - start_indices[dnum];
- CHECK_GE(dimension, 0) << "dnum = " << dnum;
- result_dimensions.push_back(dimension);
- }
- const auto result_shape =
- ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
- LayoutUtil::MinorToMajor(shape()));
- switch (result_shape.element_type()) {
- case F32:
- return SliceInternal<float>(result_shape, start_indices);
- case BF16:
- return SliceInternal<bfloat16>(result_shape, start_indices);
- case C64:
- return SliceInternal<complex64>(result_shape, start_indices);
- case S32:
- return SliceInternal<int32>(result_shape, start_indices);
- case U32:
- return SliceInternal<uint32>(result_shape, start_indices);
- default:
- LOG(FATAL) << "not yet implemented: "
- << PrimitiveType_Name(result_shape.element_type());
- }
-}
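Slice takes half-open [start, limit) ranges, one per dimension; a sketch:

    auto m = xla::LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
    auto s = m->Slice(/*start_indices=*/{0, 1}, /*limit_indices=*/{2, 3});
    // Rows [0,2) and columns [1,3): {{2, 3}, {5, 6}}.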
-
-Literal LiteralBase::Clone() const {
- Literal result(shape());
- TF_CHECK_OK(result.CopyFrom(*this));
- return result;
-}
-
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = MakeUnique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
-string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
- const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
- CHECK(LayoutUtil::IsDenseArray(subshape));
- switch (subshape.element_type()) {
- case PRED:
- return Get<bool>(multi_index, shape_index) ? "true" : "false";
- case S8:
- return StrCat(Get<int8>(multi_index, shape_index));
- case S16:
- return StrCat(Get<int16>(multi_index, shape_index));
- case S32:
- return StrCat(Get<int32>(multi_index, shape_index));
- case S64:
- return StrCat(Get<int64>(multi_index, shape_index));
- case U8:
- return StrCat(Get<uint8>(multi_index, shape_index));
- case U16:
- return StrCat(Get<uint16>(multi_index, shape_index));
- case U32:
- return StrCat(Get<uint32>(multi_index, shape_index));
- case U64:
- return StrCat(Get<uint64>(multi_index, shape_index));
- case F16:
- return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
- case F32:
- return StrCat(Get<float>(multi_index, shape_index));
- case BF16:
- return StrCat(
- static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
- case F64:
- return StrCat(Get<double>(multi_index, shape_index));
- case C64: {
- complex64 c = Get<complex64>(multi_index, shape_index);
- return StrCat("(", c.real(), ", ", c.imag(), ")");
- }
- default:
- LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
- }
-}
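GetAsString dispatches on the element type; roughly (sketch, exact float formatting depends on StrCat):

    auto p = xla::LiteralUtil::CreateR0<bool>(true);
    p->GetAsString({});  // "true"
    auto c = xla::LiteralUtil::CreateR0<xla::complex64>({1.5f, -2.0f});
    c->GetAsString({});  // "(1.5, -2)" — real/imag per the C64 branch above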
-
-string LiteralBase::GetSparseElementAsString(
- int64 sparse_element_number, const ShapeIndex& shape_index) const {
- const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
- CHECK(LayoutUtil::IsSparseArray(subshape));
- switch (subshape.element_type()) {
- case PRED:
- return GetSparseElement<bool>(sparse_element_number, shape_index)
- ? "true"
- : "false";
- case S8:
- return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
- case S16:
- return StrCat(
- GetSparseElement<int16>(sparse_element_number, shape_index));
- case S32:
- return StrCat(
- GetSparseElement<int32>(sparse_element_number, shape_index));
- case S64:
- return StrCat(
- GetSparseElement<int64>(sparse_element_number, shape_index));
- case U8:
- return StrCat(
- GetSparseElement<uint8>(sparse_element_number, shape_index));
- case U16:
- return StrCat(
- GetSparseElement<uint16>(sparse_element_number, shape_index));
- case U32:
- return StrCat(
- GetSparseElement<uint32>(sparse_element_number, shape_index));
- case U64:
- return StrCat(
- GetSparseElement<uint64>(sparse_element_number, shape_index));
- case F16:
- return StrCat(static_cast<float>(
- GetSparseElement<half>(sparse_element_number, shape_index)));
- case F32:
- return StrCat(
- GetSparseElement<float>(sparse_element_number, shape_index));
- case BF16:
- return StrCat(static_cast<float>(
- GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
- case F64:
- return StrCat(
- GetSparseElement<double>(sparse_element_number, shape_index));
- case C64: {
- complex64 c =
- GetSparseElement<complex64>(sparse_element_number, shape_index);
- return StrCat("(", c.real(), ", ", c.imag(), ")");
- }
- default:
- LOG(FATAL) << "Invalid element type for sparse arrays: "
- << PrimitiveType_Name(subshape.element_type());
- }
-}
-
-StatusOr<int64> LiteralBase::GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(LayoutUtil::IsDenseArray(shape()));
- switch (shape().element_type()) {
- case PRED:
- return Get<bool>(multi_index);
- case U8:
- return Get<uint8>(multi_index);
- case S32:
- return Get<int32>(multi_index);
- case S64:
- return Get<int64>(multi_index);
- case U32:
- return Get<uint32>(multi_index);
- case U64:
- return Get<uint64>(multi_index);
- default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
- }
-}
-
-size_t LiteralBase::Hash() const {
- using tensorflow::Hash64;
- using tensorflow::Hash64Combine;
-
- size_t hash_value = ShapeUtil::Hash(shape());
-
- ShapeUtil::ForEachSubshape(
- shape(), [&](const Shape& subshape, const ShapeIndex& index) {
- if (!ShapeUtil::IsArray(subshape)) {
- return;
- }
-
- CHECK(LayoutUtil::IsDense(subshape.layout()));
- hash_value = Hash64Combine(
- hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
- size_bytes(index)));
- });
-
- return hash_value;
-}
-
-Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value) {
- CHECK(LayoutUtil::IsDenseArray(shape()));
- switch (shape().element_type()) {
- case PRED:
- Set<bool>(multi_index, value);
- break;
- case U8:
- Set<uint8>(multi_index, value);
- break;
- case S32:
- Set<int32>(multi_index, value);
- break;
- case S64:
- Set<int64>(multi_index, value);
- break;
- case U32:
- Set<uint32>(multi_index, value);
- break;
- case U64:
- Set<uint64>(multi_index, value);
- break;
- default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
- }
- return Status::OK();
-}
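The dynamically typed integral accessors pair up like this (sketch; a dense integral literal is assumed):

    auto lit = xla::LiteralUtil::CreateR1<xla::uint8>({1, 2});
    TF_CHECK_OK(lit->SetIntegralAsS64({0}, 42));
    xla::int64 v = lit->GetIntegralAsS64({0}).ValueOrDie();  // v == 42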
-
-tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index) const {
- const Piece& p = piece(shape_index);
- CHECK_GE(sparse_element_number, 0);
- CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
- return p.sparse_indices()->At(sparse_element_number);
-}
-
-void Literal::SortSparseElements(const ShapeIndex& shape_index) {
- piece(shape_index).SortSparseElements();
-}
-
-Literal LiteralBase::GetFirstScalarLiteral() const {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
- switch (shape().element_type()) {
+/* static */ Literal LiteralUtil::GetFirstScalarLiteral(
+ const LiteralSlice& literal) {
+ CHECK(ShapeUtil::IsArray(literal.shape()));
+ CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
+ switch (literal.shape().element_type()) {
case PRED:
- return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
+ return std::move(
+ *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
// 8 bit types.
case S8:
- return std::move(*Literal::CreateR0<int8>(GetFirstElement<int8>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
case U8:
- return std::move(*Literal::CreateR0<uint8>(GetFirstElement<uint8>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
// 16 bit types.
case BF16:
- return std::move(
- *Literal::CreateR0<bfloat16>(GetFirstElement<bfloat16>()));
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>()));
case F16:
- return std::move(*Literal::CreateR0<half>(GetFirstElement<half>()));
+ return std::move(
+ *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
case S16:
- return std::move(*Literal::CreateR0<int16>(GetFirstElement<int16>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
case U16:
- return std::move(*Literal::CreateR0<uint16>(GetFirstElement<uint16>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
// 32 bit types.
case F32:
- return std::move(*Literal::CreateR0<float>(GetFirstElement<float>()));
+ return std::move(
+ *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
case S32:
- return std::move(*Literal::CreateR0<int32>(GetFirstElement<int32>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
case U32:
- return std::move(*Literal::CreateR0<uint32>(GetFirstElement<uint32>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
// 64 bit types.
case C64:
- return std::move(
- *Literal::CreateR0<complex64>(GetFirstElement<complex64>()));
+ return std::move(*LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>()));
case F64:
- return std::move(*Literal::CreateR0<double>(GetFirstElement<double>()));
- case S64:
- return std::move(*Literal::CreateR0<int64>(GetFirstElement<int64>()));
- case U64:
- return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
- default:
- LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
- }
-}
-
-void LiteralBase::Piece::SortSparseElements() {
- switch (subshape().element_type()) {
- case PRED:
- SortSparseElementsInternal<bool>();
- break;
- case S8:
- SortSparseElementsInternal<int8>();
- break;
- case U8:
- SortSparseElementsInternal<uint8>();
- break;
- case S16:
- SortSparseElementsInternal<int16>();
- break;
- case U16:
- SortSparseElementsInternal<uint16>();
- break;
- case S32:
- SortSparseElementsInternal<int32>();
- break;
- case U32:
- SortSparseElementsInternal<uint32>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
case S64:
- SortSparseElementsInternal<int64>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
case U64:
- SortSparseElementsInternal<uint64>();
- break;
- case F32:
- SortSparseElementsInternal<float>();
- break;
- case F64:
- SortSparseElementsInternal<double>();
- break;
- case C64:
- SortSparseElementsInternal<complex64>();
- break;
- case F16:
- SortSparseElementsInternal<half>();
- break;
- case BF16:
- SortSparseElementsInternal<bfloat16>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
default:
- LOG(FATAL) << "Element type not valid for sparse array: "
- << PrimitiveType_Name(subshape().element_type());
- }
-}
-
-template <typename NativeT>
-void LiteralBase::Piece::SortSparseElementsInternal() {
- CHECK(LayoutUtil::IsSparseArray(subshape()));
- int64 num_elements = sparse_indices()->index_count();
- auto values = data<NativeT>();
- CHECK_LE(num_elements, values.size());
- sparse_indices()->SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
-}
-
-namespace {
-
-void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
- bool print_layout, std::vector<string>* pieces) {
- const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
- CHECK(LayoutUtil::HasLayout(literal.shape()));
- CHECK(LayoutUtil::HasLayout(subshape));
-
- auto shape_to_string = [print_layout](const Shape& shape) {
- if (print_layout) {
- return ShapeUtil::HumanStringWithLayout(shape);
- } else {
- return ShapeUtil::HumanString(shape);
- }
- };
-
- // TODO(b/32894291): refactor this code to reduce code duplication.
- if (ShapeUtil::IsTuple(subshape)) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" (\n");
- std::vector<string> tuple_pieces;
- for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
- ShapeIndex element_index = shape_index;
- element_index.push_back(i);
- std::vector<string> element_pieces;
- ToStringHelper(literal, element_index, print_layout, &element_pieces);
- tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
- }
- pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
- pieces->push_back("\n)");
- return;
- }
-
- if (ShapeUtil::IsToken(subshape)) {
- pieces->push_back("token");
- return;
- }
-
- if (LayoutUtil::IsSparseArray(subshape)) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back("{");
- int64 rank = ShapeUtil::Rank(subshape);
- int64 num_elements = literal.sparse_element_count();
- for (int64 i = 0; i < num_elements; ++i) {
- if (i > 0) {
- pieces->push_back(", ");
- }
- if (rank == 1) {
- pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
- pieces->push_back(": ");
- } else {
- pieces->push_back("[");
- pieces->push_back(
- tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
- pieces->push_back("]: ");
- }
- pieces->push_back(literal.GetSparseElementAsString(i));
- }
- pieces->push_back("}");
- return;
- }
-
- CHECK(LayoutUtil::IsDenseArray(subshape));
-
- auto element_to_string =
- [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
- PrimitiveType element_type = subshape.element_type();
- if (element_type == PRED) {
- // We display predicates in a densely packed form.
- return literal.Get<bool>(indices, shape_index) ? "1" : "0";
- }
- return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
- literal.GetAsString(indices, shape_index);
- };
-
- if (ShapeUtil::Rank(subshape) == 0) {
- pieces->push_back(literal.GetAsString({}, shape_index));
- } else if (ShapeUtil::Rank(subshape) == 1) {
- pieces->push_back("{");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(element_to_string({i0}));
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 2) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(" { ");
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(element_to_string({i0, i1}));
- }
- pieces->push_back(" ");
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 3) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(i0 > 0 ? ",\n{" : "{");
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(i1 > 0 ? ",\n { " : " { ");
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(element_to_string({i0, i1, i2}));
- }
- pieces->push_back(" }");
- }
- pieces->push_back(" }");
- }
- pieces->push_back("\n}");
- } else if (ShapeUtil::Rank(subshape) == 4) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(" {");
- for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
- pieces->push_back(element_to_string({i0, i1, i2, i3}));
- }
- pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
- }
- pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 5) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
- for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
- pieces->push_back(" {");
- for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
- pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
- }
- pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
- : "},\n");
- }
- pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
- }
- pieces->push_back("}");
- } else {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {");
- literal.EachCellAsString(
- [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
- pieces->push_back(" ");
- pieces->push_back(value);
- });
- pieces->push_back("}");
+ LOG(FATAL) << "Unhandled primitive type "
+ << literal.shape().element_type();
}
}
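Call sites migrate from the member form to the new static helper taking a LiteralSlice, roughly (illustrative):

    auto arr = xla::LiteralUtil::CreateR2<float>({{2.5f, 0.0f}, {1.0f, 3.0f}});
    // Old: xla::Literal first = arr->GetFirstScalarLiteral();
    xla::Literal first = xla::LiteralUtil::GetFirstScalarLiteral(*arr);
    // first is an F32 scalar holding 2.5, the element at index (0, 0).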
-} // namespace
-
-int64 LiteralBase::sparse_element_count() const {
- CHECK(LayoutUtil::IsSparseArray(shape()));
- return sparse_indices()->index_count();
-}
-
-string LiteralBase::ToString(bool print_layout) const {
- std::vector<string> pieces;
- CHECK(LayoutUtil::HasLayout(this->shape()));
- ToStringHelper(*this, {}, print_layout, &pieces);
- return tensorflow::str_util::Join(pieces, "");
-}
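Reconstructed from the rank-2 branch of ToStringHelper above (not captured from a run), the output looks like:

    // auto m = xla::LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
    // m->ToString() =>
    // f32[2,2] {
    //   { 1, 2 },
    //   { 3, 4 }
    // }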
-
-/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
@@ -1537,7 +443,7 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
@@ -1550,7 +456,7 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
@@ -1565,822 +471,9 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-void LiteralBase::EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const {
- if (ShapeUtil::IsZeroElementArray(shape())) {
- return;
- }
- std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
- shape(), /*linear_index=*/0);
- do {
- per_cell(indices, GetAsString(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
-}
-
-namespace {
-template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
- CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
- src_literal.shape(),
- primitive_util::NativeToPrimitiveType<NativeDestT>()));
- auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
- int64 num_elements = src_literal.element_count();
-
- for (int64 i = 0; i < num_elements; ++i) {
- dest_data[i] = converter(src_data[i]);
- }
- return result_literal;
-}
-
-template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
- auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
- return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
- src_literal, converter);
-}
-
-template <typename NativeSrcT, typename NativeDestT>
-typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
- auto converter = [](NativeSrcT src) {
- return tensorflow::bit_cast<NativeDestT>(src);
- };
- return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
- src_literal, converter);
-}
-
-// This enable_if overload is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This overload should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
-template <typename NativeSrcT, typename NativeDestT>
-typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
- LOG(FATAL) << "Invalid bitcast between types of different sizes.";
-}
-
-template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
- CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(
- ShapeUtil::ChangeElementType(src_literal.shape(), C64));
- using NativeSrcT =
- typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
- tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.data<NativeSrcT>();
- tensorflow::gtl::MutableArraySlice<complex64> dest_data =
- result_literal->data<complex64>();
- int64 num_elements = src_literal.element_count();
- for (int64 i = 0; i < num_elements; ++i) {
- dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
- }
- return result_literal;
-}
-
-template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
- CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
- if (bitcast) {
- return BitcastBetweenNativeTypes<
- typename primitive_util::PrimitiveTypeToNative<
- primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
- } else {
- return ConvertBetweenNativeTypes<
- typename primitive_util::PrimitiveTypeToNative<
- primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
- }
-}
-
-template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
- switch (primitive_dest_type) {
-#define CONVERT_IF_TYPES_MATCH(type) \
- case (type): \
- return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
- bitcast);
- CONVERT_IF_TYPES_MATCH(PRED)
- CONVERT_IF_TYPES_MATCH(S8)
- CONVERT_IF_TYPES_MATCH(S32)
- CONVERT_IF_TYPES_MATCH(S64)
- CONVERT_IF_TYPES_MATCH(U8)
- CONVERT_IF_TYPES_MATCH(U32)
- CONVERT_IF_TYPES_MATCH(U64)
- CONVERT_IF_TYPES_MATCH(F16)
- CONVERT_IF_TYPES_MATCH(F32)
- CONVERT_IF_TYPES_MATCH(F64)
- CONVERT_IF_TYPES_MATCH(BF16)
-#undef CONVERT_IF_TYPES_MATCH
- case C64:
- if (!bitcast) {
- return ConvertToC64<primitive_src_type>(src_literal);
- }
- break;
- // Other types are not yet supported.
- default:
- break;
- }
- return Unimplemented(
- "Converting from type %s to type %s is not implemented.",
- PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
-}
-
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
- TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
- if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
- }
- switch (literal.shape().element_type()) {
-#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
- case (type): \
- return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
- bitcast);
- CONVERT_IF_DEST_TYPE_MATCHES(PRED)
- CONVERT_IF_DEST_TYPE_MATCHES(S8)
- CONVERT_IF_DEST_TYPE_MATCHES(S32)
- CONVERT_IF_DEST_TYPE_MATCHES(S64)
- CONVERT_IF_DEST_TYPE_MATCHES(U8)
- CONVERT_IF_DEST_TYPE_MATCHES(U32)
- CONVERT_IF_DEST_TYPE_MATCHES(U64)
- CONVERT_IF_DEST_TYPE_MATCHES(F16)
- CONVERT_IF_DEST_TYPE_MATCHES(F32)
- CONVERT_IF_DEST_TYPE_MATCHES(F64)
- CONVERT_IF_DEST_TYPE_MATCHES(BF16)
-#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
- default:
- return Unimplemented(
- "%s from type %s to type %s is not implemented.",
- (bitcast ? "Bitcast converting" : "Converting"),
- PrimitiveType_Name(literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
- }
-}
-
-} // namespace
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
- PrimitiveType primitive_dest_type) const {
- return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
-}
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
- PrimitiveType primitive_dest_type) const {
- if (primitive_util::BitWidth(shape().element_type()) !=
- primitive_util::BitWidth(primitive_dest_type)) {
- return InvalidArgument(
- "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
- "%d",
- PrimitiveType_Name(shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str(),
- primitive_util::BitWidth(shape().element_type()),
- primitive_util::BitWidth(primitive_dest_type));
- }
- return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
-}
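The bit-width guard admits only same-width reinterpretations; a sketch:

    auto f = xla::LiteralUtil::CreateR0<float>(1.0f);
    auto bits = f->BitcastConvert(xla::S32).ValueOrDie();
    // IEEE-754 pattern of 1.0f: bits->Get<xla::int32>({}) == 0x3f800000.
    // f->BitcastConvert(xla::F64) fails: 32 vs. 64 bit widths.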
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
- if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
- return Convert(dest_shape.element_type());
- }
- std::vector<Literal> elements;
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- auto element = LiteralSlice(*this, {i});
- TF_ASSIGN_OR_RETURN(
- auto new_element,
- element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
- }
- auto converted = MakeUnique<Literal>();
- *converted = Literal::MoveIntoTuple(&elements);
- return std::move(converted);
-}
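Sketch of the rounding path for a scalar F32 -> BF16 conversion (illustrative):

    auto f = xla::LiteralUtil::CreateR0<float>(3.14159f);
    auto bf = f->ConvertToShape(xla::ShapeUtil::MakeShape(xla::BF16, {}),
                                /*round_f32_to_bf16=*/true)
                  .ValueOrDie();  // BF16 scalar, rounded rather than truncated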
-
-template <typename NativeT>
-bool LiteralBase::Piece::EqualElementsInternal(
- const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
- if (multi_index->size() == ShapeUtil::Rank(subshape())) {
- return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
- }
- for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
- multi_index->push_back(i);
- if (!EqualElementsInternal<NativeT>(other, multi_index)) {
- return false;
- }
- multi_index->pop_back();
- }
- return true;
-}
-
-bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
- DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
-
- std::vector<int64> multi_index;
- switch (subshape().element_type()) {
- case PRED:
- return EqualElementsInternal<bool>(other, &multi_index);
- case U8:
- return EqualElementsInternal<uint8>(other, &multi_index);
- case S32:
- return EqualElementsInternal<int32>(other, &multi_index);
- case S64:
- return EqualElementsInternal<int64>(other, &multi_index);
- case U32:
- return EqualElementsInternal<uint32>(other, &multi_index);
- case U64:
- return EqualElementsInternal<uint64>(other, &multi_index);
- case F32:
- return EqualElementsInternal<float>(other, &multi_index);
- case F64:
- return EqualElementsInternal<double>(other, &multi_index);
- case F16:
- return EqualElementsInternal<half>(other, &multi_index);
- case BF16:
- return EqualElementsInternal<bfloat16>(other, &multi_index);
- case C64:
- return EqualElementsInternal<complex64>(other, &multi_index);
- default:
- LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
- << PrimitiveType_Name(subshape().element_type());
- }
-}
-
-bool LiteralBase::operator==(const LiteralBase& other) const {
- if (!ShapeUtil::Compatible(shape(), other.shape())) {
- return false;
- }
-
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- const Piece& other_piece = other.piece(index);
- if (!piece.EqualElements(other_piece)) {
- return false;
- }
- return true;
- });
-}
-
-namespace {
-
-template <typename NativeT>
-static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
- NativeT value) {
- for (int64 i = 0; i < data.size(); ++i) {
- if (data[i] != value) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace
-
-bool LiteralBase::IsAll(int8 value) const {
- return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
- const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case U8:
- if (value >= 0) {
- return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
- }
- return false;
- case U32:
- if (value >= 0) {
- return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
- }
- return false;
- case U64:
- if (value >= 0) {
- return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
- }
- return false;
- case S8:
- return AllElementsEqualValue<int8>(piece.data<int8>(), value);
- case S32:
- return AllElementsEqualValue<int32>(piece.data<int32>(), value);
- case S64:
- return AllElementsEqualValue<int64>(piece.data<int64>(), value);
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
- static_cast<bfloat16>(value));
- case PRED:
- if (value == 0) {
- return AllElementsEqualValue<bool>(piece.data<bool>(), false);
- }
- if (value == 1) {
- return AllElementsEqualValue<bool>(piece.data<bool>(), true);
- }
- return false;
- default:
- return false;
- }
- return false;
- };
-
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
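A sketch of the value-conversion rules implemented above (illustrative):

    auto z = xla::LiteralUtil::CreateR1<float>({0.0f, 0.0f});
    z->IsAll(0);   // true: 0 is exactly representable in F32
    auto u = xla::LiteralUtil::CreateR0<xla::uint8>(255);
    u->IsAll(-1);  // false: negative values never match unsigned literals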
-
-bool LiteralBase::IsAllFloat(float value) const {
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(
- piece.data<bfloat16>(), static_cast<bfloat16>(value));
- default:
- return false;
- }
- };
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
-
-bool LiteralBase::IsAllComplex(complex64 value) const {
- switch (shape().element_type()) {
- case C64:
- return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
- value);
- default:
- return false;
- }
-}
-
-bool LiteralBase::IsAllFirst() const {
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- // Empty shapes are not all the first element since there is no first
- // element.
- if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
- return false;
- }
- auto piece_is_all = [&]() {
- switch (piece.subshape().element_type()) {
- case PRED: {
- auto data = piece.data<bool>();
- return AllElementsEqualValue<bool>(data, data[0]);
- }
- // 8 bit types
- case S8: {
- auto data = piece.data<int8>();
- return AllElementsEqualValue<int8>(data, data[0]);
- }
- case U8: {
- auto data = piece.data<uint8>();
- return AllElementsEqualValue<uint8>(data, data[0]);
- }
- // 16 bit types
- case BF16: {
- auto data = piece.data<bfloat16>();
- return AllElementsEqualValue<bfloat16>(data, data[0]);
- }
- case F16: {
- auto data = piece.data<half>();
- return AllElementsEqualValue<half>(data, data[0]);
- }
- case S16: {
- auto data = piece.data<int16>();
- return AllElementsEqualValue<int16>(data, data[0]);
- }
- case U16: {
- auto data = piece.data<uint16>();
- return AllElementsEqualValue<uint16>(data, data[0]);
- }
- // 32 bit types
- case F32: {
- auto data = piece.data<float>();
- return AllElementsEqualValue<float>(data, data[0]);
- }
- case U32: {
- auto data = piece.data<uint32>();
- return AllElementsEqualValue<uint32>(data, data[0]);
- }
- case S32: {
- auto data = piece.data<int32>();
- return AllElementsEqualValue<int32>(data, data[0]);
- }
- // 64 bit types
- case C64: {
- auto data = piece.data<complex64>();
- return AllElementsEqualValue<complex64>(data, data[0]);
- }
- case F64: {
- auto data = piece.data<double>();
- return AllElementsEqualValue<double>(data, data[0]);
- }
- case S64: {
- auto data = piece.data<int64>();
- return AllElementsEqualValue<int64>(data, data[0]);
- }
- case U64: {
- auto data = piece.data<uint64>();
- return AllElementsEqualValue<uint64>(data, data[0]);
- }
- default:
- return false;
- }
- };
-
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
-
-bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
- CHECK(ShapeUtil::IsArray(shape()));
- switch (shape().element_type()) {
- case U8:
- return Get<uint8>(indices) == 0;
- case U32:
- return Get<uint32>(indices) == 0;
- case U64:
- return Get<uint64>(indices) == 0;
- case S8:
- return Get<int8>(indices) == 0;
- case S32:
- return Get<int32>(indices) == 0;
- case S64:
- return Get<int64>(indices) == 0;
- case F32:
- return Get<float>(indices) == 0.0f;
- case F64:
- return Get<double>(indices) == 0.0;
- case C64:
- return Get<complex64>(indices) == complex64(0.0f, 0.0f);
- case F16:
- return Get<half>(indices) == static_cast<half>(0.0f);
- case BF16:
- return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
- case PRED:
- return Get<bool>(indices) == false;
- default:
- LOG(FATAL) << "Input literal must be an array.";
- }
-}
-
-namespace {
-
-template <typename RepeatedFieldT, typename NativeT>
-void CopyToRepeatedField(RepeatedFieldT* dest,
- const tensorflow::gtl::ArraySlice<NativeT> src) {
- *dest = RepeatedFieldT(src.begin(), src.end());
-}
-
-} // namespace
-
-void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
- *proto->mutable_shape() = subshape();
- switch (subshape().element_type()) {
- case PRED:
- CopyToRepeatedField(proto->mutable_preds(), data<bool>());
- break;
- case U8:
- proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
- element_count());
- break;
- case U32:
- CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
- break;
- case U64:
- CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
- break;
- case S32:
- CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
- break;
- case S64:
- CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
- break;
- case F16:
- *proto->mutable_f16s() = string(
- reinterpret_cast<const char*>(data<half>().data()), size_bytes());
- if (!kLittleEndian) {
- ConvertEndianShort(proto->mutable_f16s());
- }
- break;
- case BF16:
- *proto->mutable_bf16s() = string(
- reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
- if (!kLittleEndian) {
- ConvertEndianShort(proto->mutable_bf16s());
- }
- break;
- case F32:
- CopyToRepeatedField(proto->mutable_f32s(), data<float>());
- break;
- case F64:
- CopyToRepeatedField(proto->mutable_f64s(), data<double>());
- break;
- case C64:
- for (complex64 value : data<complex64>()) {
- proto->add_c64s(value.real());
- proto->add_c64s(value.imag());
- }
- break;
- case TUPLE:
- case TOKEN:
- // Nothing to do but assign the shape which is done above.
- return;
- default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
- }
-}
-
-const void* LiteralBase::Piece::untyped_data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- return buffer();
-}
-
-void* LiteralBase::Piece::untyped_data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- return buffer();
-}
-
-namespace {
-
-template <typename RepeatedFieldT, typename NativeT>
-Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- const RepeatedFieldT& src) {
- if (dest.size() != src.size()) {
- return InvalidArgument(
- "Expected %lu elements in LiteralProto repeated field, has %d",
- dest.size(), src.size());
- }
- std::copy(src.begin(), src.end(), dest.begin());
- return Status::OK();
-}
-
-} // namespace
-
-Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
- // These conditions should have been checked in Literal::CreateFromProto.
- TF_RET_CHECK(proto.has_shape());
- TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
- TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
-
- switch (subshape().element_type()) {
- case PRED:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
- break;
- case U8: {
- auto u8_data = data<uint8>();
- TF_RET_CHECK(proto.u8s().size() == u8_data.size());
- std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
- } break;
- case S32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
- break;
- case S64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
- break;
- case U32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
- break;
- case U64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
- break;
- case F16: {
- const string& s(proto.f16s());
- TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
- memcpy(untyped_data(), s.data(), s.size());
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
- }
- } break;
-
- case BF16: {
- const string& s(proto.bf16s());
- TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
- memcpy(untyped_data(), s.data(), s.size());
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
- }
- } break;
- case F32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
- break;
- case F64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
- break;
- case C64: {
- auto complex_data = data<complex64>();
- TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
- for (int64 i = 0; i < complex_data.size(); ++i) {
- complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
- }
- } break;
- case TUPLE:
- LOG(FATAL) << "Should not be called on tuple shapes: "
- << ShapeUtil::HumanString(subshape());
- break;
- default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
- }
- return Status::OK();
-}
-
-LiteralProto LiteralBase::ToProto() const {
- LiteralProto proto;
- root_piece().ForEachSubpiece(
- [&](const ShapeIndex& index, const Piece& piece) {
- LiteralProto* proto_piece = &proto;
- for (int64 i : index) {
- while (proto_piece->tuple_literals_size() <= i) {
- proto_piece->add_tuple_literals();
- }
- proto_piece = proto_piece->mutable_tuple_literals(i);
- }
- piece.WriteToProto(proto_piece);
- });
-
- if (LayoutUtil::IsSparseArray(shape())) {
- CopyToRepeatedField(proto.mutable_sparse_indices(),
- sparse_indices()->data());
- }
-
- return proto;
-}
-
-/* static */
-StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
- const LiteralProto& proto) {
- if (!proto.has_shape()) {
- return InvalidArgument("LiteralProto has no shape");
- }
- if (!LayoutUtil::HasLayout(proto.shape())) {
- return InvalidArgument("LiteralProto has no layout");
- }
-
- auto literal = MakeUnique<Literal>(proto.shape());
-
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
- [&](const ShapeIndex& index, Piece* piece) {
- const LiteralProto* proto_element = &proto;
- for (int64 i : index) {
- CHECK(i < proto_element->tuple_literals_size());
- proto_element = &proto_element->tuple_literals(i);
- }
-
- if (ShapeUtil::IsTuple(piece->subshape())) {
- if (proto_element->tuple_literals_size() !=
- ShapeUtil::TupleElementCount(piece->subshape())) {
- return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
- ShapeUtil::TupleElementCount(piece->subshape()),
- proto_element->tuple_literals_size());
- }
- return Status::OK();
- }
- if (piece->subshape().element_type() == TOKEN) {
- return Status::OK();
- }
-
- CHECK(ShapeUtil::IsArray(piece->subshape()));
- TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
-
- return Status::OK();
- }));
-
- return std::move(literal);
-}
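ToProto and CreateFromProto round-trip, as a sketch:

    auto lit = xla::LiteralUtil::CreateR1<xla::int32>({1, 2, 3});
    xla::LiteralProto proto = lit->ToProto();
    auto back = xla::Literal::CreateFromProto(proto).ValueOrDie();
    CHECK(*back == *lit);  // compatible shape and equal element values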
-
-/* static */ string Literal::MultiIndexAsString(
+/* static */ string LiteralUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
}
-const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
- return piece(shape_index).untyped_data();
-}
-
-void* Literal::untyped_data(const ShapeIndex& shape_index) {
- return piece(shape_index).untyped_data();
-}
-
-int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
- return piece(shape_index).size_bytes();
-}
-
-string LiteralBase::GetR1U8AsString() const {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(shape().element_type(), U8);
- return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
- ShapeUtil::ElementsIn(shape()));
-}
-
-void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
- CHECK(ShapeUtil::IsTuple(shape));
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& subshape = shape.tuple_shapes(i);
-
- auto child_piece = Piece();
- child_piece.set_subshape(&subshape);
-
- if (ShapeUtil::IsTuple(subshape)) {
- BuildPieceSubtree(subshape, &child_piece);
- }
-
- piece->emplace_back(std::move(child_piece));
- }
-}
-
-LiteralSlice::LiteralSlice(const LiteralBase& literal)
- : LiteralBase(), root_piece_(&literal.root_piece()) {}
-
-LiteralSlice::LiteralSlice(const LiteralBase& literal,
- const ShapeIndex& view_root)
- : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
-
-BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
- CHECK(ShapeUtil::IsArray(*shape_));
- CHECK(LayoutUtil::HasLayout(*shape_));
-
- root_piece_ = Piece();
- root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
- root_piece_.set_subshape(shape_.get());
-}
-
-BorrowingLiteral::BorrowingLiteral(
- tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
- CHECK(ShapeUtil::IsTuple(*shape_));
- CHECK(!ShapeUtil::IsNestedTuple(*shape_));
- CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
- root_piece_ = Piece();
- root_piece_.set_subshape(shape_.get());
- BuildPieceSubtree(*shape_, &root_piece_);
-
- for (int i = 0; i < src_buf_ptrs.size(); ++i) {
- const auto& src_shape = shape_->tuple_shapes(i);
- CHECK(ShapeUtil::IsArray(src_shape));
- root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
- }
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 37ca8ea9f1..e3737a9d00 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -51,679 +52,12 @@ limitations under the License.
namespace xla {
-// Forward declare the Literal and LiteralSlice classes used by the creation
-// methods in the base class.
-class Literal;
-class LiteralSlice;
-
-// Abstract base class for literals.
-class LiteralBase {
+class LiteralUtil {
public:
- virtual ~LiteralBase() = 0;
-
- // Literals are equal if they have compatible shapes and the same data
- // values. Layout is not compared.
- bool operator==(const LiteralBase& other) const;
- bool operator!=(const LiteralBase& other) const { return !(*this == other); }
-
- // Returns the shape of the literal.
- const Shape& shape() const { return root_piece().subshape(); }
-
- // Serialize to proto.
- LiteralProto ToProto() const;
-
- // Returns an ArraySlice of the array for this literal for the given NativeT
- // (e.g., float). CHECKs if the subshape of the literal at the given
- // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
- // to native type.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
-
- // Returns a const pointer to the sparse index array. Returns nullptr if the
- // literal is not a sparse array.
- const SparseIndexArray* sparse_indices(
- const ShapeIndex& shape_index = {}) const;
-
- // Returns a const pointer to (or size of) the underlying buffer holding the
- // array at the given shape index. CHECKs if the subshape of the literal at
- // the given ShapeIndex is not an array.
- const void* untyped_data(const ShapeIndex& shape_index = {}) const;
- int64 size_bytes(const ShapeIndex& shape_index = {}) const;
-
- // Returns this literal's data as a string. This literal must be a rank-1 U8
- // array.
- string GetR1U8AsString() const;
-
- // Returns a string representation of the literal value.
- // Warning: this function can take minutes for multi-million element Literals.
- string ToString(bool print_layout = false) const;
-
- // Gets an element in the literal at the given index. The multi_index is
- // CHECKed against the dimension sizes.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const;
- // Overloads of Get for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // Returns the element value at index (0, ..., 0), however many zeroes are
- // required for that index.
- template <typename NativeT>
- NativeT GetFirstElement() const;
-
- // As Get(), but determines the correct type and converts the value
- // into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index = {}) const;
- // As GetSparseElement(), but determines the correct type and converts the
- // value into text.
- string GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
- // As Get(), but determines the correct type and converts the value into
- // int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // Returns the multi-index of the element in a sparse literal at the given
- // sparse element number. The sparse element number is the position within
- // the sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
-
- // Returns the value of the element in a sparse literal at the given sparse
- // element number. The sparse element number is the position within the
- // sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- template <typename NativeT>
- NativeT GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // Invokes the "per cell" callback for each element in the provided
- // literal with the element's indices and a string representation of
- // the element's value.
- //
- // This function is useful if you want a polymorphic representation
- // of the tensor's elements (turning it to a string for something
- // like representation in a protobuf).
- //
- // This literal must have a dense layout.
- void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const;
- template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
-
- // Returns whether every element in this literal is equal to value.
- //
- // value is an int8 because we expect this to be called with small
- // compile-time constants (0, -1, etc.) and so that whatever value you pass
- // can be represented exactly by floating-point types as small as 16 bits.
- //
- // If value doesn't fit in this literal's type, returns false. Values of 1/0
- // are considered equal to true/false; other values are not considered equal
- // to true. Also, if this literal is not array-shaped, false is returned.
- bool IsAll(int8 value) const;
-
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular floating-point number.
- //
- // If the literal is not a floating-point value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for values that can be expressed precisely as a float,
- // e.g. -0.5. Also, if this literal is not array-shaped, false is returned.
- bool IsAllFloat(float value) const;
-
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular complex number.
- //
- // If the literal is not a complex value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for complex values that can be expressed precisely as
- // float pairs e.g. (-0.5, 1.0).
- //
- // This literal must have a dense layout.
- bool IsAllComplex(complex64 value) const;
-
- // Returns whether the literal consists entirely of its first element.
- bool IsAllFirst() const;
-
- // Returns whether this literal is zero at the specified index. This literal
- // must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
-
- // Returns the count of the elements in the array at the given shape index in
- // this literal.
- int64 element_count(const ShapeIndex& index = {}) const {
- return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
- }
-
- // Returns the count of the elements in the sparse array at the given shape
- // index in this literal, which will be no larger than
- // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()).
- int64 sparse_element_count() const;
-
- // Compute a hash for this literal. This literal must not be a sparse tensor
- // or a tuple containing a sparse tensor.
- size_t Hash() const;
-
- // Converts this literal to the given shape. Returns an error if the
- // conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_f32_to_bf16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
-
- // Converts this literal to another primitive type using a bitcast
- // conversion. The to and from primitive types must have the same bit
- // width. Returns an error if the conversion is not possible. This literal
- // must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to another primitive type. Returns an error if the
- // conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
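A small sketch of the three conversion entry points (values illustrative; each returns a StatusOr that the caller must check):

    auto s32 = LiteralUtil::CreateR1<int32>({1, 2, 3});
    StatusOr<std::unique_ptr<Literal>> f32 = s32->Convert(F32);          // value-converting
    StatusOr<std::unique_ptr<Literal>> bits = s32->BitcastConvert(U32);  // same bit width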
+ LiteralUtil() = delete;
// Returns a literal scalar representing the first element.
- Literal GetFirstScalarLiteral() const;
-
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
- Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
-
- // TODO(b/67651157): The methods below which perform computation on Literals
- // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
- // evaluator code which operates on Literals.
- //
- // Creates a new literal equivalent in value to this one, but
- // conforming to new_layout; e.g. a literal matrix that was in {0,
- // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
- // minor-to-major dimension layout and the value in the cell at any given
- // logical index (i0, i1) will be the same.
- //
- // For tuple shaped literals, shape_index should be used to select the inner
- // array that the new layout applies to.
- //
- // Note: this is useful when the client wants to ensure that a value placed in
- // the XLA allocation tracker has a particular layout, e.g. for efficiency or
- // to avoid unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
-
- // An overload of Relayout which changes the layout of the entire shape rather
- // than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
-
- // Creates a new literal by reshaping this literal to have the given
- // dimensions. The total number of elements must not change; the
- // implementation currently only supports monotonic dim0-major layouts.
- // This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by broadcasting this literal with `dimensions` to
- // yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by reordering the dimensions of this literal.
- // The given `permutation` must be a permutation of the dimension numbers
- // in the original literal, and it specifies the order of the new dimensions
- // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
- // For example, a transpose call on a literal of shape [3 x 8 x 4] and
- // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
- // This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
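The worked example from the comment above, as code (the source literal is an illustrative f32[3x8x4]):

    // literal has shape f32[3x8x4].
    std::unique_ptr<Literal> transposed = literal->Transpose({2, 0, 1});
    // transposed has shape f32[4x3x8]; transposed({a, b, c}) == literal({b, c, a}).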
-
- // Creates a sub-array from this literal by extracting the indices
- // [start_index, limit_index) of each dimension. The result literal has the
- // same rank and layout as for the given literal. The number of indices in
- // start_indices and limit_indices must be the rank of the literal, and the
- // indices follow the order of the dimensions.
- // This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
-
- // Creates a literal with a prepended dimension with bound "times"; e.g. a
- // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
- // literal replicated four times.
- // This literal must be an array.
- template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
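A combined sketch of Slice and Replicate (values illustrative):

    auto m = LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});
    auto block = m->Slice({0, 0}, {2, 2});      // f32[2x2]: {{1, 2}, {4, 5}}
    auto stacked = block->Replicate<float>(4);  // f32[4x2x2], four copies of block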
-
- // Creates a new Literal object with the shape specified as parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- //
- // Note: It's an antipattern to use this method then immediately call
- // Literal::Populate on the result (since that results in zero initialization,
- // then reinitialization). Consider whether a call to MakeUnique<Literal>(shape),
- // followed by a call to Literal::Populate, can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
-
- protected:
- // A data structure representing a subshape at a particular ShapeIndex within
- // the literal. For array-shaped ShapeIndexes, this data structure holds the
- // pointer to the memory allocated for the array data.
- class Piece {
- public:
- // Returns the buffer holding the array data for this piece as an array
- // slice. This piece must be array-shaped.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
-
- // Returns the buffer holding the array data for this piece as a void*. This
- // piece must be array-shaped.
- void* untyped_data();
- const void* untyped_data() const;
-
- // Gets or sets an element in the array at the given index. The multi_index
- // is CHECKed against the dimension sizes of the array. This piece must be
- // array-shaped.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
-
- // Gets/sets the buffer holding the array data.
- char* buffer() const { return buffer_; }
- void set_buffer(char* buffer) { buffer_ = buffer; }
-
- // The array of multi-indices that provide the locations of non-zero
- // elements in a sparse array. Only used if
- // LayoutUtil::IsSparseArray(shape()) is true.
- SparseIndexArray* sparse_indices() const { return sparse_indices_; }
- void set_sparse_indices(SparseIndexArray* sparse_indices) {
- sparse_indices_ = sparse_indices;
- }
-
- // Gets or sets the subshape of this piece. This reference points to a
- // subshape within the shape in the containing Literal (Literal::shape_).
- const Shape& subshape() const { return *subshape_; }
- void set_subshape(const Shape* subshape) { subshape_ = subshape; }
-
- // Returns the size in bytes of the buffer holding the array data.
- int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
-
- // Returns the number of elements in this piece's array.
- int64 element_count() const {
- // If this is a sparse array, use the number of elements represented by
- // the indices in the associated SparseIndexArray.
- return LayoutUtil::IsSparseArray(subshape())
- ? sparse_indices()->index_count()
- : ShapeUtil::ElementsIn(subshape());
- }
-
- // Returns the child piece at 'index' of this piece.
- Piece& child(int64 index) { return children_[index]; }
-
- // Adds a child piece to this piece's children.
- void emplace_back(Piece child_piece) {
- children_.emplace_back(std::move(child_piece));
- }
-
- // Returns the number of child pieces of this piece.
- int64 children_size() { return children_.size(); }
-
- // Visitor function that recursively traverses the piece and calls the
- // given function at each subpiece. The function has the type:
- // void (const ShapeIndex& index, const Piece& piece)
- template <typename Fn>
- void ForEachSubpiece(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelper(
- [&func](const ShapeIndex& index, const Piece& piece) {
- func(index, piece);
- return Status::OK();
- },
- *this, &index)
- .IgnoreError();
- }
- // Same as above, but the function has the type:
- // Status (const ShapeIndex& index, const Piece& piece)
- // The first non-OK return value is returned by the function.
- template <typename Fn>
- Status ForEachSubpieceWithStatus(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelper(func, *this, &index);
- }
- // Same as above, but the function has the type:
- // bool (const ShapeIndex& index, const Piece& piece)
- // The first false return value stops the traversal and is returned.
- template <typename Fn>
- bool ForEachSubpieceWithBool(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelperBool(func, *this, &index);
- }
- // Same as above, but the function has the type:
- // void (const ShapeIndex& index, Piece& piece)
- template <typename Fn>
- void ForEachMutableSubpiece(const Fn& func) {
- ShapeIndex index;
- return ForEachMutableHelper(
- [&func](const ShapeIndex& index, Piece* piece) {
- func(index, piece);
- return Status::OK();
- },
- const_cast<xla::LiteralBase::Piece*>(this), &index)
- .IgnoreError();
- }
- // Same as above, but the function has the type:
- // Status (const ShapeIndex& index, Piece& piece)
- // The first non-OK return value is returned by the function.
- template <typename Fn>
- Status ForEachMutableSubpieceWithStatus(const Fn& func) {
- ShapeIndex index;
- return ForEachMutableHelper(
- func, const_cast<xla::LiteralBase::Piece*>(this), &index);
- }
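A sketch of the traversal helpers in use. Piece is protected, so the caller here is assumed to be a LiteralBase method or one of the friend classes listed below; the counting logic is illustrative:

    int64 array_pieces = 0;
    root_piece().ForEachSubpiece(
        [&](const ShapeIndex& index, const Piece& piece) {
          if (ShapeUtil::IsArray(piece.subshape())) {
            ++array_pieces;  // count the array-shaped subpieces
          }
        });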
-
- // Returns true if this piece and 'other' contain the same data. This piece
- // and 'other' must be array-shaped and compatible.
- bool EqualElements(const Piece& other) const;
-
- // Writes the shape and data (if array-shaped) into the given proto.
- void WriteToProto(LiteralProto* proto) const;
-
- // Copies the data from 'src' into this piece's buffer. Shapes of this piece
- // and src must be compatible.
- Status CopyFrom(const Piece& src);
-
- // Copies the data from the given proto into this piece. The shape of this
- // piece must be equal (not just compatible) to the shape of the proto.
- Status CopyFromProto(const LiteralProto& proto);
-
- // Sorts the elements in a sparse array.
- void SortSparseElements();
-
- private:
- // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
- // The first non-OK (or non-true) value is returned by the function.
- // The callable 'func' has the same signature as described above in
- // ForEachSubpiece*.
- template <typename Fn>
- Status ForEachHelper(const Fn& func, const Piece& piece,
- ShapeIndex* index) const {
- TF_RETURN_IF_ERROR(func(*index, piece));
- for (int64 i = 0; i < piece.children_.size(); ++i) {
- index->push_back(i);
- TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
- index->pop_back();
- }
- return Status::OK();
- }
- template <typename Fn>
- bool ForEachHelperBool(const Fn& func, const Piece& piece,
- ShapeIndex* index) const {
- if (!func(*index, piece)) {
- return false;
- }
- for (int64 i = 0; i < piece.children_.size(); ++i) {
- index->push_back(i);
- if (!ForEachHelperBool(func, piece.children_[i], index)) {
- return false;
- }
- index->pop_back();
- }
- return true;
- }
- template <typename Fn>
- Status ForEachMutableHelper(const Fn& func, Piece* piece,
- ShapeIndex* index) {
- TF_RETURN_IF_ERROR(func(*index, piece));
- for (int64 i = 0; i < piece->children_.size(); ++i) {
- index->push_back(i);
- TF_RETURN_IF_ERROR(
- ForEachMutableHelper(func, &piece->children_[i], index));
- index->pop_back();
- }
- return Status::OK();
- }
-
- // Recursive helper for EqualElements.
- template <typename NativeT>
- bool EqualElementsInternal(const Piece& other,
- std::vector<int64>* multi_index) const;
-
- // Helper for SortSparseElements that has the element type as a template
- // parameter.
- template <typename NativeT>
- void SortSparseElementsInternal();
-
- // For array-shaped pieces, this is the buffer holding the literal data.
- char* buffer_ = nullptr;
-
- // For sparse arrays, this is the array of indices.
- SparseIndexArray* sparse_indices_ = nullptr;
-
- // The shape of this piece. This points into the shape of the containing
- // (Literal::shape_).
- const Shape* subshape_ = nullptr;
-
- // Children pieces for tuple shaped pieces.
- std::vector<Piece> children_ = {};
- }; // class Piece
-
- const Piece& piece(const ShapeIndex& shape_index) const {
- Piece* piece = &const_cast<Piece&>(root_piece());
- for (const auto i : shape_index) {
- DCHECK_GE(i, 0);
- DCHECK_LT(i, piece->children_size());
- piece = &piece->child(i);
- }
- return *piece;
- }
-
- // Returns the piece at the root of the shape.
- virtual const Piece& root_piece() const = 0;
-
- // Literal, LiteralSlice, and BorrowingLiteral must access Pieces of other
- // Literals.
- friend class Literal;
- friend class LiteralSlice;
- friend class BorrowingLiteral;
-
- private:
- template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const;
-};
-
-// Class representing literal values in XLA.
-//
-// The underlying buffer and shape is always owned by this class.
-class Literal : public LiteralBase {
- public:
- Literal() : Literal(ShapeUtil::MakeNil()) {}
-
- // Create a literal of the given shape. The literal is allocated sufficient
- // memory to hold the shape. Memory is uninitialized.
- explicit Literal(const Shape& shape);
- virtual ~Literal();
-
- // Literals are moveable, but not copyable. To copy a literal use
- // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
- // of literals which can be expensive.
- Literal(const Literal& other) = delete;
- Literal& operator=(const Literal& other) = delete;
- Literal(Literal&& other);
- // 'allocate_arrays' indicates whether to allocate memory for the arrays in
- // the shape. If false, buffer pointers inside of the Literal::Pieces are set
- // to nullptr.
- Literal(const Shape& shape, bool allocate_arrays);
- Literal& operator=(Literal&& other);
-
- // TODO(b/67651157): Remove this accessor. Literal users should not be able to
- // mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return shape_.get(); }
-
- // Returns a MutableArraySlice view of the array for this literal for the
- // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
- // given ShapeIndex is not array. See primitive_util.h for the mapping from
- // XLA type to native type.
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {});
- // Unhide const method from parent class.
- using LiteralBase::data;
-
- // Returns a pointer to the sparse index array. Returns nullptr if the literal
- // is not a sparse array.
- SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
-
- // Returns a pointer to the underlying buffer holding the array at the given
- // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
- // is not array.
- void* untyped_data(const ShapeIndex& shape_index = {});
- // Unhide const method from parent class.
- using LiteralBase::untyped_data;
-
- // Populates a literal with a sparse layout with the given indices and values.
- // Each index in the indices array is CHECKed against the dimensions in the
- // literal's shape. If sort is true, then the indices and values will be
- // sorted. If sort is false, then the indices and values are assumed to
- // already be in sorted order. See CreateSparse for an example of how data
- // are populated.
- template <typename NativeT>
- void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
-
- // Copy values from 'src_literal' rooted at 'src_shape_index' into this
- // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
- // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
- // rooted at 'src_shape_index', but need not be arrays.
- Status CopyFrom(const LiteralSlice& src_literal,
- const ShapeIndex& dest_shape_index = {},
- const ShapeIndex& src_shape_index = {});
-
- // Similar to CopyFrom, but with move semantics. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
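A usage sketch, assuming dest is a tuple-shaped Literal whose elements are compatible with the illustrative src and donor literals (the enclosing function returns Status):

    TF_RETURN_IF_ERROR(dest.CopyFrom(src, /*dest_shape_index=*/{0}));
    TF_RETURN_IF_ERROR(dest.MoveFrom(std::move(donor), /*dest_shape_index=*/{1}));
    // donor is now a nil shape (empty tuple); its buffers live on in dest.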
-
- // Copies the values from src_literal, starting at src_base shape indexes,
- // to this literal, starting at dest_base, where the copy size in each
- // dimension is specified by copy_size.
- // The src_literal and this literal must have the same primitive type,
- // src_base+copy_size must fit the source literal dimensions, and
- // dest_base+copy_size must fit the destination literal dimensions.
- // Note: if either src_literal or this literal contains dimensions with zero
- // elements, then copy_size must be 0 in those dimensions and the
- // corresponding base indices must be 0.
- // This literal and 'src_literal' must be arrays.
- Status CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Copies one element from src_literal[src_index] to (*this)[dest_index].
- Status CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
-
- // Sets an element in the literal at the given index. The multi_index is
- // CHECKed against the dimension sizes.
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
- // Overloads of Set for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
-
- // Appends the given element to the literal. If the elements are not appended
- // in sorted order, then SortSparseElements should be called before calling
- // other methods. This literal must have a sparse layout.
- template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
-
- // Sorts the elements in a sparse array.
- void SortSparseElements(const ShapeIndex& shape_index = {});
-
- // As Set(), but truncates `value` to the literal element type before storing.
- // This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
-
- // Populate this literal with the given values. Examples:
- //
- // // Populate with floats.
- // Array2D<float> float_values = ...
- // literal.PopulateR2FromArray2D(values);
- //
- // // Populate with int32s.
- // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
- //
- // The shape and element type of this literal must match the given values. For
- // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
- // array of S32.
- template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
- void PopulateR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- void PopulateFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- void PopulateR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- void PopulateR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-
- // Populates literal values by calling the generator function for every cell
- // in this literal object.
- //
- // generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
- //
- // This literal must have a dense layout.
- template <typename NativeT, typename FnType>
- Status Populate(const FnType& generator);
-
- // A parallel version of Populate(). This can be used if the generator is
- // thread-safe and the values for the shape's different elements are
- // independent.
- template <typename NativeT, typename FnType>
- Status PopulateParallel(const FnType& generator);
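A sketch of the generator form, filling an f32[4x4] so that element (i, j) holds 10*i + j:

    Literal literal(ShapeUtil::MakeShape(F32, {4, 4}));
    TF_CHECK_OK(literal.Populate<float>(
        [](tensorflow::gtl::ArraySlice<int64> indexes) {
          return static_cast<float>(indexes[0] * 10 + indexes[1]);
        }));

PopulateParallel takes the same generator; it is only safe when the generator is thread-safe, as noted above.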
-
- // Fills this literal with the given value.
- template <typename NativeT>
- void PopulateWithValue(NativeT value);
-
- // Factory methods below.
- //
-
- // Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
@@ -889,7 +223,7 @@ class Literal : public LiteralBase {
// As above, but intended to be invoked with move semantics; i.e.
//
// std::vector<std::unique_ptr<Literal>> elements = ...;
- // auto result = Literal::MakeTupleOwned(std::move(elements));
+ // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
@@ -899,7 +233,7 @@ class Literal : public LiteralBase {
// This overload lets you pass a braced list of unique_ptr<Literal>s to
// MakeTupleOwned:
//
- // Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
+ // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
// Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
// overload doesn't work because std::initializer_list's elements are always
@@ -920,19 +254,6 @@ class Literal : public LiteralBase {
// Create a constant token literal. Token types have no value.
static std::unique_ptr<Literal> CreateToken();
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple).
- std::vector<Literal> DecomposeTuple();
-
- // This operation is the inverse of DecomposeTuple. The given elements are
- // moved into the tuple elements of a new tuple-shaped Literal which is
- // returned. Upon return, each of the Literals in 'elements' is set to a nil
- // shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
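A round-trip sketch of the two operations (tuple is an illustrative tuple-shaped Literal):

    std::vector<Literal> elements = tuple.DecomposeTuple();  // tuple becomes nil
    Literal rebuilt = Literal::MoveIntoTuple(
        tensorflow::gtl::MutableArraySlice<Literal>(elements.data(),
                                                    elements.size()));
    // Each entry of elements is now nil; rebuilt owns the original buffers.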
-
// Creates a new Literal object with its values having the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
@@ -1000,194 +321,12 @@ class Literal : public LiteralBase {
// dimension 1 equal to 8.
static string MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index);
-
- private:
- // Recursively sets the subshapes and buffers of all subpieces rooted at
- // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays in
- // the shape.
- void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
-
- // Returns the piece at the given ShapeIndex.
- Piece& piece(const ShapeIndex& shape_index) {
- return const_cast<Piece&>(LiteralBase::piece(shape_index));
- }
-
- Piece& root_piece() const override { return *root_piece_; };
-
- // Internal template helper for the Literal::CopySliceFrom(), matching its
- // arguments one by one.
- template <typename NativeT>
- Status CopySliceFromInternal(const LiteralBase& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Utility structure which is used to create the optimal configuration for
- // a ShapeUtil::ForEachIndex() scan across two literals.
- struct StrideConfig {
- StrideConfig(const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // The dimensions of the stride operation. Essentially every dimension
- // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
- // steps.
- tensorflow::gtl::ArraySlice<int64> dimensions;
- DimensionVector base;
- DimensionVector step;
- int64 minor_dimension = 0;
- // The stride sizes for source and destination. One of the two
- // (the one looping through its most minor dimension) will be 1, while
- // the other will be the stride size at the dimension matching the most
- // minor dimension of the other shape being scanned.
- int64 dest_stride = 1;
- int64 source_stride = 1;
- // The size of the inner loop on the most minor dimension.
- int64 minor_loop_size = 1;
- };
-
- // Literal class always owns the shape. The parent class borrows this shape.
- std::unique_ptr<Shape> shape_;
-
- Piece* root_piece_ = nullptr;
-
- // Implementation details shared between Populate() and PopulateParallel()
- template <typename NativeT, typename FnType>
- Status PopulateInternal(const FnType& generator, bool parallel);
-
- // Deallocate the buffers held by this literal.
- void DeallocateBuffers();
-
- friend class LiteralBase;
-};
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
-
-// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
-// literal buffers always owned by others.
-class LiteralSlice : public LiteralBase {
- public:
- LiteralSlice() : LiteralBase() {}
-
- // Implicit conversion constructors.
- LiteralSlice(const LiteralBase& literal);
- LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
-
- private:
- const Piece& root_piece() const override { return *root_piece_; };
-
- const Piece* root_piece_; // Not owned.
-};
-
-// A read-only Literal where the underlying buffers are never owned by this
-// class.
-class BorrowingLiteral : public LiteralBase {
- public:
- BorrowingLiteral() : LiteralBase() {}
-
- // 'src_buf_ptr' is not owned by this class and must outlive the
- // lifetime of this class. It points to an appropriately sized buffer with
- // data interpreted as indicated by 'shape'.
- // This constructor is only used for array shapes.
- BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
- // Similar to the above, except to be used for constructing non-nested tuples.
- BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
- const Shape& shape);
- // TODO(b/79707221): add constructors for nested tuples as well.
-
- private:
- // Recursively builds the subtree for the given piece and sets the subshapes
- // of the given piece with the given shape.
- void BuildPieceSubtree(const Shape& shape, Piece* piece);
-
- // Accessor for the root piece of this literal.
- const Piece& root_piece() const override { return root_piece_; };
- Piece root_piece_;
-
- // Shape of this literal. Stored as a unique_ptr so that the (default)
- // move construction of this class is trivially correct: the pointer to the
- // Shape that root_piece_ stores will still point to the correct address.
- std::unique_ptr<Shape> shape_;
};
-template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
- << "Attempting to access "
- << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
- << " type, but literal element type is "
- << PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::ArraySlice<NativeT>(
- reinterpret_cast<const NativeT*>(buffer()), element_count());
-}
-
-template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
- << "Attempting to access "
- << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
- << " type, but literal element type is "
- << PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::MutableArraySlice<NativeT>(
- reinterpret_cast<NativeT*>(buffer()), element_count());
-}
-
-template <typename NativeT>
-NativeT LiteralBase::Piece::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(LayoutUtil::IsDenseArray(subshape()));
- return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
- subshape(), multi_index)];
-}
-
-template <typename NativeT>
-void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
- CHECK(LayoutUtil::IsDenseArray(subshape()));
- data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
- subshape(), multi_index)] = value;
-}
-
-template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
- const ShapeIndex& shape_index) const {
- return piece(shape_index).data<NativeT>();
-}
-
-template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
- const ShapeIndex& shape_index) {
- return piece(shape_index).data<NativeT>();
-}
-
-template <typename NativeT>
-inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
- return piece(shape_index).Get<NativeT>(multi_index);
-}
-
-template <typename NativeT>
-inline NativeT LiteralBase::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- return root_piece().Get<NativeT>(multi_index);
-}
-
-template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
- return piece(shape_index).Set<NativeT>(multi_index, value);
-}
-
-template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
- return root_piece().Set<NativeT>(multi_index, value);
-}
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
literal->Set({}, value);
@@ -1195,7 +334,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR1(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
@@ -1205,7 +344,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
@@ -1218,13 +357,13 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
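After this change the creation helpers are spelled LiteralUtil::..., as throughout the test updates below; e.g.:

    auto scalar = LiteralUtil::CreateR0<float>(2.5f);        // f32[]
    auto vector = LiteralUtil::CreateR1<int32>({1, 2, 3});   // s32[3]
    auto matrix = LiteralUtil::CreateR2<float>({{1.f, 2.f},  // f32[2x2]
                                                {3.f, 4.f}});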
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -1249,14 +388,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -1287,7 +426,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateSparse(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
int64 num_elements = values.size();
@@ -1302,7 +441,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -1310,7 +449,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
@@ -1320,38 +459,40 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -1376,7 +517,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -1404,49 +545,21 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
-template <typename NativeT>
-NativeT LiteralBase::GetFirstElement() const {
- return data<NativeT>().at(0);
-}
-
-template <typename NativeT>
-NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
- CHECK(
- LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
- return data<NativeT>(shape_index)[sparse_element_number];
-}
-
-template <typename NativeT>
-void Literal::AppendSparseElement(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
- const ShapeIndex& shape_index) {
- Piece& p = piece(shape_index);
- const Shape& subshape = p.subshape();
- CHECK(LayoutUtil::IsSparseArray(subshape));
- int64 rank = ShapeUtil::Rank(subshape);
- CHECK_EQ(multi_index.size(), rank);
- int64 last_element = p.sparse_indices()->index_count();
- CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
- p.sparse_indices()->Append(multi_index);
- CHECK_LT(last_element, p.data<NativeT>().size());
- p.data<NativeT>()[last_element] = value;
-}
-
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -1455,174 +568,8 @@ template <typename NativeT>
}
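Usage sketch:

    auto eye = LiteralUtil::MakeIdentityR2<float>(3);
    // eye is f32[3x3]: {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}.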
template <typename NativeT>
-void LiteralBase::EachCell(
- std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const {
- if (ShapeUtil::IsZeroElementArray(shape())) {
- return;
- }
- std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
- do {
- per_cell(indices, Get<NativeT>(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
-}
-
-template <typename NativeT>
-inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- for (int64 i = 0; i < values.size(); ++i) {
- Set({i}, values[i]);
- }
-}
-
-template <typename NativeT>
-void Literal::PopulateR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 2);
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
-
- const int64 dim0_size = values.size();
- const int64 dim1_size = values.begin()->size();
- CHECK_EQ(dim0_size, shape().dimensions(0));
- CHECK_EQ(dim1_size, shape().dimensions(1));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto value : inner_list) {
- Set({dim0, dim1}, value);
- ++dim1;
- }
- CHECK_EQ(dim1_size, dim1);
- ++dim0;
- }
-}
-
-template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
- for (int dim = 0; dim < values.num_dimensions(); ++dim) {
- CHECK_EQ(values.dim(dim), shape().dimensions(dim));
- }
- values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value) { this->Set(indices, value); });
-}
-
-template <typename NativeT>
-void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
- CHECK(LayoutUtil::IsSparseArray(shape()));
- int rank = ShapeUtil::Rank(shape());
- CHECK_EQ(indices.rank(), rank);
- int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
- CHECK_LE(indices.max_indices(), max_elements);
- int64 num_elements = values.size();
- CHECK_LE(num_elements, max_elements);
- CHECK_EQ(num_elements, indices.index_count());
- auto root_data = root_piece().data<NativeT>();
- // Piece::data() returns an ArraySlice of size equal to the number of indices
- // in the SparseIndexArray. So there is no need to adjust the size of the data
- // here. It is enough to just copy the incoming values into the data buffer.
- std::copy(values.begin(), values.end(), root_data.begin());
- *this->root_piece().sparse_indices() = std::move(indices);
- if (sort) {
- auto root_data = this->root_piece().data<NativeT>();
- this->root_piece().sparse_indices()->SortWithValues(root_data);
- }
- DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
-}
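A usage sketch. The sparse-shape helper ShapeUtil::MakeShapeWithSparseLayout and the SparseIndexArray(max_indices, rank) constructor are assumptions about APIs not shown in this diff; only Append and PopulateSparse appear above:

    Literal sparse(ShapeUtil::MakeShapeWithSparseLayout(
        F32, {4, 4}, /*max_sparse_elements=*/8));
    SparseIndexArray indices(/*max_indices=*/8, /*rank=*/2);
    indices.Append({0, 1});
    indices.Append({3, 2});
    sparse.PopulateSparse(std::move(indices),
                          tensorflow::gtl::ArraySlice<float>({10.f, 20.f}));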
-
-template <typename NativeT, typename FnType>
-Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
- const Shape& this_shape = shape();
- const int64 rank = ShapeUtil::Rank(this_shape);
- TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
- TF_RET_CHECK(this_shape.element_type() ==
- primitive_util::NativeToPrimitiveType<NativeT>());
- tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
- if (rank > 0) {
- StrideConfig stride_config(this_shape, this_shape,
- AsInt64Slice(this_shape.dimensions()));
- int64 minor_dimension_size =
- ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
-
- auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- DimensionVector minor_scan_indexes(rank, 0);
- const int64 index =
- IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
- std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
- for (int64 i = 0; i < minor_dimension_size; ++i) {
- minor_scan_indexes[stride_config.minor_dimension] = i;
- literal_data.at(index + i) = generator(minor_scan_indexes);
- }
- };
- if (parallel) {
- ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
- stride_config.dimensions,
- stride_config.step, init_function);
- } else {
- ShapeUtil::ForEachIndex(
- this_shape, stride_config.base, stride_config.dimensions,
- stride_config.step,
- [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
- init_function(indexes);
- return true;
- });
- }
- } else {
- // For scalars.
- literal_data.at(0) = generator({});
- }
- return Status::OK();
-}
-template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
- return PopulateInternal<NativeT>(generator, /*parallel=*/false);
-}
-
-template <typename NativeT, typename FnType>
-Status Literal::PopulateParallel(const FnType& generator) {
- return PopulateInternal<NativeT>(generator, /*parallel=*/true);
-}
-
-template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- for (NativeT& element : data<NativeT>()) {
- element = value;
- }
-}
-
-template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
@@ -1630,44 +577,9 @@ template <typename NativeT>
return literal;
}
-template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
- DimensionVector bounds = {times};
- bounds.reserve(shape().dimensions_size() + 1);
- for (int64 bound : shape().dimensions()) {
- bounds.push_back(bound);
- }
- auto literal =
- MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
- if (elements == 0) {
- return literal;
- }
-
- DimensionVector output_indices(bounds.size(), 0);
- tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
- input_indices.remove_prefix(1);
-
- bool done = false;
- while (!done) {
- const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
-
- done = true;
- for (int n = 0; n < output_indices.size(); ++n) {
- ++output_indices[n];
- if (output_indices[n] < bounds[n]) {
- done = false;
- break;
- }
- output_indices[n] = 0;
- }
- }
- return literal;
-}
-
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
@@ -1681,8 +593,9 @@ template <PrimitiveType type, typename T>
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev) {
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
+ T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -1692,8 +605,8 @@ template <PrimitiveType type, typename E, typename T>
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
- const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
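A usage sketch of the default-engine overload (the template argument is the XLA PrimitiveType; mean and stddev parametrize the normal distribution):

    StatusOr<std::unique_ptr<Literal>> random =
        LiteralUtil::CreateRandomLiteral<F32>(
            ShapeUtil::MakeShape(F32, {2, 3}), /*mean=*/0.0f, /*stddev=*/1.0f);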
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 857aae0a79..6b7fd10d63 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 45a9fe0127..98dccaa9a2 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 22cc4e2436..fe346f9956 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -33,6 +33,7 @@ cc_library(
srcs = ["numpy_bridge.cc"],
hdrs = ["numpy_bridge.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -70,7 +71,7 @@ tf_py_wrap_cc(
deps = [
":local_computation_builder",
":numpy_bridge",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:cpu_plugin",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index c44e69e615..afdea88cb7 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,7 +109,7 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 68648a3a17..71351abd59 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -374,7 +375,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
elements.push_back(std::move(literal));
}
- return Literal::MakeTupleOwned(std::move(elements));
+ return LiteralUtil::MakeTupleOwned(std::move(elements));
} else if (PyArray_Check(o)) {
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
int rank = PyArray_NDIM(py_array);
@@ -383,7 +384,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
dimensions[i] = PyArray_DIM(py_array, i);
}
int np_type = PyArray_TYPE(py_array);
- auto literal = Literal::CreateFromDimensions(
+ auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
TF_RETURN_IF_ERROR(
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 64f0aae0f9..a67c93a4fb 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -25,7 +25,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 27aee634ba..e2b6eaa096 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -461,14 +461,16 @@ class LocalComputation(object):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
+ result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
+
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
- result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
result_shape = result_shape.map_leaves(layout_fn)
- compile_options = compile_options or CompileOptions()
- compile_options.result_shape = result_shape
+
+ compile_options = compile_options or CompileOptions()
+ compile_options.result_shape = result_shape
return LocalComputation(
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index c289c84cff..6397f1f479 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
ConvolutionDimensionNumbers dnums) {
HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
std::array<int64, 2> ordered_kernel_strides;
std::array<int64, 2> ordered_input_dimensions;
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 9da9bc60a2..8091bed499 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test {
TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
{11.f, 12.f},
});
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
- auto actual_literal = Literal::CreateR1<float>(*result);
+ auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
ErrorSpec(0.0001));
}
@@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
- auto actual_literal = Literal::CreateR1<float>(*result);
+ auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
- auto result = Literal::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
+ auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
@@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
ErrorSpec(0.0001));
}
@@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
return value + row + col;
};
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
@@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
};
auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
@@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
*actual_literal, ErrorSpec(0.0001));
@@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
@@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto result =
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}});
- auto actual_literal = Literal::CreateR3FromArray3D(*result);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
@@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto result =
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}});
- auto actual_literal = Literal::CreateR3FromArray3D(*result);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
@@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}},
{{1, 1, 1, 1}});
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
@@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}},
{{1, 2, 2, 2}});
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
@@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame);
Array3D<float> expected = {{{17, 28, 39, 20}}};
- auto actual_literal = Literal::CreateR3FromArray3D(*actual);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid);
Array3D<float> expected = {{{17, 28, 39}}};
- auto actual_literal = Literal::CreateR3FromArray3D(*actual);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
}));
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
}));
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
}});
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
Array4D<float> expected({{{{2514, 2685}}}});
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
auto actual = ReferenceUtil::ApplyElementwise2D(
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
- auto actual_literal = Literal::CreateR2FromArray2D(*actual);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
}
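
The hunks above are purely mechanical: the literal factory functions (CreateR1, CreateR2FromArray2D, CreateR3FromArray3D, CreateR4FromArray4D, and friends) move from static methods on Literal to static methods on LiteralUtil, and the include switches from literal_util.h to literal.h where only the Literal type is needed. A minimal sketch of the call-site change, with the test scaffolding elided; arguments and the std::unique_ptr<Literal> return type are unchanged:

    #include "tensorflow/compiler/xla/literal.h"       // declares the Literal type
    #include "tensorflow/compiler/xla/literal_util.h"  // declares the factories

    // Before: auto actual_literal = Literal::CreateR1<float>({6.f, 15.f});
    // After:  only the class hosting the factory changes.
    auto actual_literal = LiteralUtil::CreateR1<float>({6.f, 15.f});
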
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index f8414468bd..90efee50b4 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -97,7 +97,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<float>(expected);
+ LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fe99f700d2..989bb759e3 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -136,7 +136,7 @@ cc_library(
":hlo_dce",
":hlo_pass",
":tuple_simplifier",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -182,6 +182,7 @@ tf_cc_test(
name = "shape_inference_test",
srcs = ["shape_inference_test.cc"],
deps = [
+ ":hlo",
":shape_inference",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -227,6 +228,7 @@ cc_library(
":hlo",
":hlo_query",
":shape_inference",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -244,7 +246,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_evaluator",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -294,6 +296,7 @@ cc_library(
":hlo_reachability",
":name_uniquer",
"//tensorflow/compiler/xla:array",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_tree",
@@ -396,6 +399,7 @@ tf_cc_test(
deps = [
":hlo_matchers",
":hlo_parser",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -407,7 +411,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_parser",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -424,7 +428,7 @@ tf_cc_test(
srcs = ["hlo_sharding_test.cc"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -453,7 +457,7 @@ tf_cc_test(
srcs = ["call_graph_test.cc"],
deps = [
":call_graph",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -487,6 +491,7 @@ cc_library(
hdrs = ["call_inliner.h"],
deps = [
":call_graph",
+ ":hlo_dce",
":hlo_pass",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
@@ -502,7 +507,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -521,7 +526,7 @@ tf_cc_test(
deps = [
":call_graph",
":flatten_call_graph",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -797,7 +802,7 @@ cc_library(
hdrs = ["transfer_manager.h"],
deps = [
":shaped_buffer",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -960,7 +965,7 @@ tf_cc_test(
":hlo",
":hlo_ordering",
":hlo_scheduling",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1038,7 +1043,7 @@ tf_cc_test(
":hlo_ordering",
":hlo_value",
":tuple_points_to_analysis",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1121,7 +1126,7 @@ cc_library(
hdrs = ["hlo_query.h"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
],
)
@@ -1170,6 +1175,7 @@ cc_library(
deps = [
":hlo",
":shape_inference",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1200,6 +1206,7 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1219,6 +1226,7 @@ cc_library(
":hlo_creation_utils",
":hlo_pass",
":while_util",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
],
@@ -1232,8 +1240,9 @@ tf_cc_test(
":batchnorm_expander",
":hlo",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1255,6 +1264,7 @@ cc_library(
":hlo_pass",
":hlo_query",
":pattern_matcher",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1274,7 +1284,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1310,7 +1320,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1345,7 +1355,7 @@ cc_library(
":call_inliner",
":hlo",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -1361,6 +1371,7 @@ tf_cc_test(
":conditional_simplifier",
":hlo",
":hlo_matchers",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1420,7 +1431,7 @@ tf_cc_test(
deps = [
":defuser",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
@@ -1448,7 +1459,7 @@ tf_cc_test(
deps = [
":hlo_matchers",
":implicit_broadcast_remover",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
@@ -1490,7 +1501,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":tuple_simplifier",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1505,7 +1516,7 @@ cc_library(
hdrs = ["reshape_mover.h"],
deps = [
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -1520,7 +1531,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":reshape_mover",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1555,7 +1566,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":inliner",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1572,7 +1583,7 @@ cc_library(
hdrs = ["computation_placer.h"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -1604,7 +1615,7 @@ cc_library(
hdrs = ["generic_transfer_manager.h"],
deps = [
":transfer_manager",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1695,7 +1706,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1710,6 +1721,7 @@ tf_cc_binary(
deps = [
":hlo",
":hlo_graph_dumper",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1724,7 +1736,7 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1822,7 +1834,7 @@ tf_cc_test(
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -1859,7 +1871,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_liveness_analysis",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -1920,7 +1932,7 @@ tf_cc_test(
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1955,6 +1967,7 @@ cc_library(
":hlo_dataflow_analysis",
":logical_buffer",
":logical_buffer_analysis",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1973,6 +1986,7 @@ tf_cc_test(
":hlo_matchers",
":instruction_fusion",
":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -2044,7 +2058,7 @@ tf_cc_test(
":hlo_graph_dumper",
":hlo_matchers",
":hlo_runner",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -2108,6 +2122,7 @@ tf_cc_test(
srcs = ["hlo_verifier_test.cc"],
deps = [
":hlo",
+ ":hlo_parser",
":hlo_verifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -2169,6 +2184,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dce",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -2189,7 +2205,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_module_dce",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -2213,7 +2229,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":layout_assignment",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -2272,7 +2288,7 @@ cc_library(
":hlo",
":hlo_domain_map",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -2288,7 +2304,7 @@ tf_cc_test(
":hlo",
":hlo_cse",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -2310,7 +2326,7 @@ cc_library(
":hlo_evaluator",
":hlo_pass",
":hlo_query",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
@@ -2325,7 +2341,7 @@ tf_cc_test(
":hlo_constant_folding",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -2363,6 +2379,20 @@ cc_library(
)

cc_library(
+ name = "hlo_domain_verifier",
+ srcs = ["hlo_domain_verifier.cc"],
+ hdrs = ["hlo_domain_verifier.h"],
+ deps = [
+ ":hlo",
+ ":hlo_domain_map",
+ ":hlo_graph_dumper",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "hlo_domain_isolator",
srcs = ["hlo_domain_isolator.cc"],
hdrs = ["hlo_domain_isolator.h"],
@@ -2381,8 +2411,8 @@ cc_library(
hdrs = ["hlo_domain_remover.h"],
deps = [
":hlo",
- ":hlo_domain_isolator",
":hlo_domain_map",
+ ":hlo_domain_verifier",
":hlo_graph_dumper",
":hlo_pass",
"//tensorflow/compiler/xla:types",
@@ -2417,7 +2447,7 @@ cc_library(
":hlo_evaluator",
":hlo_pass",
":hlo_query",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
@@ -2552,7 +2582,7 @@ cc_library(
hdrs = ["hlo_tfgraph_builder.h"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
@@ -2583,7 +2613,7 @@ cc_library(
":hlo_casting_utils",
":hlo_execution_profile",
":hlo_tfgraph_builder",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:window_util",
@@ -2601,6 +2631,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_graph_dumper",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -2632,7 +2663,7 @@ tf_cc_test(
":hlo_matchers",
":shape_inference",
":transpose_folding",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -2653,7 +2684,7 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -2668,7 +2699,7 @@ tf_cc_test(
":hlo",
":shape_inference",
":zero_sized_hlo_elimination",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -2828,6 +2859,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":tuple_util",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:lib",
],
)
@@ -2963,6 +2995,7 @@ cc_library(
":hlo",
":hlo_lexer",
":hlo_sharding_metadata",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 1ddeb27e40..af7728da54 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -195,7 +196,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -537,8 +538,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar =
- MakeUnique<Literal>(constant->literal().GetFirstScalarLiteral());
+ std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
return ReplaceWithNewInstruction(
@@ -1093,7 +1094,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
ShapeUtil::IsZeroElementArray(lhs->shape()) ||
ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
@@ -1519,7 +1520,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- Literal::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).CloneToUnique());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1554,7 +1555,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -2098,7 +2099,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::Zero(convolution->shape().element_type())
+ LiteralUtil::Zero(convolution->shape().element_type())
.CloneToUnique())),
{}));
}
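
Past the rename, the simplifier's constant-materialization pattern is unchanged: build a scalar with a LiteralUtil element-type helper, clone it into a unique_ptr for CreateConstant, and broadcast it to the output shape. A condensed sketch of the pattern shared by AddReduce, HandleDot, and HandleConvolution above; computation and shape stand in for the surrounding visitor state:

    // Minimal sketch, assuming `computation` (HloComputation*) and `shape`
    // (the target Shape) come from the surrounding visitor.
    HloInstruction* zero =
        computation->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
    // An empty dimension list broadcasts the scalar to the whole shape.
    HloInstruction* bcast = computation->AddInstruction(
        HloInstruction::CreateBroadcast(shape, zero, {}));
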
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index b733f6f59e..92bbcbd740 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
@@ -79,7 +79,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
HloComputation::Builder builder(TestName());
// Create add computation.
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
@@ -119,7 +119,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
@@ -140,9 +140,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));

HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
@@ -165,7 +165,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction* bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
builder.AddInstruction(
@@ -200,7 +200,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateMap(
r2f32,
{param0, builder.AddInstruction(
@@ -223,7 +223,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
HloInstruction* bcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
builder.AddInstruction(
@@ -242,7 +242,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14f, 3.14f, 3.14f})));
+ LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -258,7 +258,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14, 3.14, 4})));
+ LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -277,7 +277,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
@@ -298,7 +298,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kSubtract, param0, constant));
@@ -493,7 +493,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
@@ -559,7 +559,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
@@ -580,7 +580,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
@@ -860,7 +860,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
@@ -884,7 +884,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
@@ -912,7 +912,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
@@ -934,7 +934,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
@@ -956,7 +956,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* negative_one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
@@ -1047,7 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
window, add_computation));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
@@ -1074,7 +1074,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
module().AddEntryComputation(builder.Build());
EXPECT_THAT(module().entry_computation()->root_instruction(),
@@ -1116,7 +1116,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@@ -1208,7 +1208,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
@@ -1238,7 +1238,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
@@ -1420,7 +1420,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
@@ -1443,7 +1443,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
@@ -1726,7 +1726,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
@@ -1757,7 +1757,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
@@ -2109,7 +2109,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloComputation::Builder builder(TestName());
HloInstruction* forty_two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
HloInstruction* broadcast = builder.AddInstruction(
@@ -2156,7 +2156,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
@@ -2187,7 +2187,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, pad, reduce_init_value, window,
@@ -2238,7 +2238,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
@@ -2273,7 +2273,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, convert, reduce_init_value, window,
@@ -2344,9 +2344,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
HloComputation::Builder call_builder(TestName() + ".Call");
HloInstruction* zero = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
HloInstruction* one = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
call_builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
@@ -2362,9 +2362,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get()});
+ std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2387,8 +2387,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
shape,
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "slice_from")),
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int>({0, 0, 0}))),
/*slice_sizes=*/{10, 100, 1000}));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2421,8 +2421,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
builder.AddInstruction(
HloInstruction::CreateParameter(2, slice_shape, "to_update")),
slice,
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int>({0, 0, 0})))));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -2437,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
HloComputation::Builder builder(TestName());
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* input_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
HloInstruction* inner_bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
@@ -2546,7 +2546,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
pad_shape, input,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
HloComputation* add_computation = nullptr;
@@ -2565,7 +2565,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
Window window = window_util::MakeWindow(
decorate_spatials(param.reduce_window_spatials, 1, 1));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
ShapeInference::InferReduceWindowShape(
pad->shape(), zero->shape(), window,
@@ -2704,7 +2704,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
@@ -2783,7 +2783,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
DotDimensionNumbers dot_dnums;
@@ -2830,7 +2830,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
HloInstruction* const update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
HloInstruction* const start_indices = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int>({0})));
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
dslice_shape, operand, update, start_indices));
const HloComputation* const computation =
@@ -2879,7 +2879,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -2887,7 +2887,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int32 start_col = (spec.lcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
@@ -2898,7 +2898,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -2946,7 +2946,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -2957,7 +2957,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -2965,7 +2965,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int32 start_col = (spec.rcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
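
The tuple-literal path moves too: LiteralUtil::MakeTuple keeps the old contract, taking raw, non-owning element pointers (hence the .get() calls on the factory results) and copying them into a new tuple literal. Condensed from the ConstantTupleBecomesTupleOfConstants hunk above:

    // MakeTuple copies the pointed-to elements into the new tuple literal,
    // so the unique_ptr temporaries may die at the end of the statement.
    std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
        {LiteralUtil::CreateR0<float>(7.3f).get(),
         LiteralUtil::CreateR1<float>({1.1f, 2.0f, 3.3f}).get()});
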
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec13fadbc7..c4cd60c120 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -41,6 +43,8 @@ namespace xla {
namespace {
+using tensorflow::gtl::optional;
+
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -97,7 +101,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-0.5f))))),
+ LiteralUtil::CreateR0<float>(-0.5f))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
operand, exponent);
@@ -113,7 +117,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(1.0 / element_count))))),
+ LiteralUtil::CreateR0<float>(1.0 / element_count))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
operand, elem_count_recip);
@@ -200,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
HloInstruction* offset = batch_norm->mutable_operand(2);
const Shape feature_shape = scale->shape();
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
@@ -288,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
+ const HloSharding& sharding = batch_norm->sharding();
HloSharding operand_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
inst->set_sharding(operand_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
@@ -320,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
@@ -388,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
- inst->set_sharding(batch_norm->sharding());
+ inst->set_sharding(sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- shifted_normalized->set_sharding(batch_norm->sharding());
+ shifted_normalized->set_sharding(sharding);
}
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
@@ -447,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 feature_count = activation_shape.dimensions(feature_index);
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
@@ -542,7 +558,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
auto elements_per_feature_literal =
- Literal::CreateR0<float>(elements_per_feature_int64);
+ LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
@@ -562,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
HloSharding activation_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ auto unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
inst->set_sharding(activation_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
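Aside: all three batch-norm hunks above stamp shardings onto the expanded instructions by the same rule. A minimal sketch of that rule, assuming this revision's APIs (sharding_unique_device() returning an optional<int64>) and using DefaultShardingFor as a hypothetical helper name:

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {

// If the original batch-norm was pinned to a single device, keep the helper
// instructions on that device; otherwise fall back to replication. (Helpers
// whose shape matches the operand instead inherit the operand's sharding, as
// in the hunks above.)
HloSharding DefaultShardingFor(const HloInstruction* batch_norm) {
  auto unique_device = batch_norm->sharding_unique_device();
  return unique_device.has_value()
             ? HloSharding::AssignDevice(unique_device.value())
             : HloSharding::Replicate();
}

}  // namespace xla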
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aa36e64b07..32f785a70a 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -19,12 +19,13 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -114,5 +115,33 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}
+TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
+ const char* module_str = R"(
+HloModule module
+ENTRY entry {
+ %param.0 = f32[8,4] parameter(0)
+ %param.1 = f32[4] parameter(1)
+ %param.2 = f32[4] parameter(2)
+ ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
+ batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
+ epsilon=0.001, feature_index=1, sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ BatchNormExpander rewriter(/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
+ /*rewrite_grad_op=*/true);
+ ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+
+ for (auto* instruction : module->entry_computation()->instructions()) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ ASSERT_TRUE(instruction->has_sharding());
+ TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
+ EXPECT_EQ(device, 1);
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index ff6d5027ef..b21c83a07f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -615,7 +615,6 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
// (1) a is F32 but tuple is BF16
// (2) after adding conversion
// (3) after tuple simplifier and DCE.
- bool needs_tuple_simplifier = false;
for (auto computation : module->MakeComputationPostOrder()) {
auto insts = computation->MakeInstructionPostOrder();
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
@@ -629,67 +628,25 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
continue;
}
ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
- // Iterate through nodes in the shape tree in pre-order and initialize
- // each non-root node with a corresponding get-tuple-element. For a leaf
- // node, if its shape does not match the fusion output, create a
- // conversion node to overwrite the node value.
- for (auto it = converted_outputs.begin(); it != converted_outputs.end();
- ++it) {
- ShapeIndex output_index = it->first;
- HloInstruction*& output = it->second;
- const Shape subshape =
- ShapeUtil::GetSubshape(hlo->shape(), output_index);
- if (output_index.empty()) {
- output = fusion_root;
- } else {
- ShapeIndex parent_index = output_index;
- parent_index.pop_back();
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- subshape, converted_outputs.element(parent_index),
- output_index.back()));
- }
- if (!ShapeUtil::IsArray(subshape)) {
- continue;
- }
- if (!ShapeUtil::Compatible(
- subshape,
- ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) {
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateConvert(subshape, output));
- }
- }
- // Iterate through nodes in the shape tree in reverse pre-order and create
- // a tuple instruction for each non-leaf node where the elements are the
- // values of its child nodes.
- for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend();
- ++it) {
- ShapeIndex output_index = it->first;
- HloInstruction*& output = it->second;
- const Shape& subshape =
- ShapeUtil::GetSubshape(hlo->shape(), output_index);
- if (!ShapeUtil::IsTuple(subshape)) {
- continue;
- }
- std::vector<HloInstruction*> elements(
- ShapeUtil::TupleElementCount(subshape));
- ShapeIndex child_index = output_index;
- for (int64 i = 0; i < elements.size(); ++i) {
- child_index.push_back(i);
- elements[i] = converted_outputs.element(child_index);
- child_index.pop_back();
- }
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateTuple(elements));
- }
- fusion_computation->set_root_instruction(converted_outputs.element({}));
- needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
+ // Deep copy the fusion root, and convert a leaf node only if its shape
+ // does not match the fusion output.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * copy,
+ fusion_computation->DeepCopyInstructionWithCustomCopier(
+ fusion_root,
+ [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* comp) {
+ const Shape& hlo_subshape =
+ ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
+ if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
+ return leaf;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreateConvert(hlo_subshape, leaf));
+ }));
+ fusion_computation->set_root_instruction(copy);
}
}
- if (needs_tuple_simplifier) {
- TupleSimplifier tuple_simplifier;
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- }
return Status::OK();
}
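Aside: both this hunk and the Run() change below lean on DeepCopyInstructionWithCustomCopier. A declaration sketch, inferred from the call sites in this diff rather than quoted from hlo_computation.h:

// Member of HloComputation. Deep-copies `instruction` by walking its shape
// tree, rebuilding tuple structure and calling `copy_leaf` at each array
// leaf; returning the leaf unchanged elides the copy at that position.
StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
    HloInstruction* instruction,
    const std::function<HloInstruction*(
        HloInstruction* leaf, const ShapeIndex& leaf_index,
        HloComputation* computation)>& copy_leaf);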
@@ -758,10 +715,38 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
changes_to_bf16_.clear();
changed_ = false;
+ auto computations_topological_order = module->MakeComputationPostOrder();
+
+ // Before running the propagation pass, we insert copies (kConvert to the same
+ // type) of F32 inputs to while loops. This prevents other uses of the same
+ // input from aliasing the while loop input/output, giving a greater chance
+ // of using BF16 inside the loop. If some of these added copies do not
+ // help, they will remain F32 after BF16 propagation and will be removed since
+ // they are no-ops.
+ for (auto computation : computations_topological_order) {
+ for (auto inst : computation->MakeInstructionPostOrder()) {
+ if (inst->opcode() != HloOpcode::kWhile) {
+ continue;
+ }
+
+ auto operand = inst->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * copy,
+ computation->DeepCopyInstructionWithCustomCopier(
+ operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* comp) {
+ if (leaf->shape().element_type() != F32) {
+ return leaf;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreateConvert(leaf->shape(), leaf));
+ }));
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
+ }
+ }
+
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
- const auto& computations_topological_order =
- module->MakeComputationPostOrder();
// The first step is a forward pass (parameters to root), where we determine
// the potential candidate instructions to use bfloat16 in the outputs that
// are not likely to cause overhead from extra explicit conversions. This is
@@ -810,23 +795,27 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
}
}
+ // Removes redundant HLOs added by this pass, either when inserting
+ // de-aliasing copies of while loop inputs, or later when converting output
+ // types.
+ auto clean_up = [this, module]() {
+ TF_RETURN_IF_ERROR(SkipNoopConversions(module));
+ TupleSimplifier tuple_simplifier;
+ TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+ HloDCE dce;
+ TF_RETURN_IF_ERROR(dce.Run(module).status());
+ return Status::OK();
+ };
+
if (!changed_) {
+ TF_RETURN_IF_ERROR(clean_up());
return false;
}
TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));
- // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 ->
- // BF16), so we skip them now.
- TF_RETURN_IF_ERROR(SkipNoopConversions(module));
-
- {
- // We may have dead HLOs after ResolveInconsistentFusions,
- // ResolveConvertedConstants and SkipNoopConversions.
- HloDCE dce;
- TF_RETURN_IF_ERROR(dce.Run(module).status());
- }
+ TF_RETURN_IF_ERROR(clean_up());
return true;
}
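Aside: the reason clean_up() now runs even on the early-exit path is that the pre-pass above may have added converts that end up as no-ops. A sketch of the criterion, assuming it mirrors what SkipNoopConversions checks:

// A convert is a no-op when its element type already matches its operand's.
// The pre-inserted F32 copies become exactly this (F32 -> F32) whenever
// propagation leaves the surrounding while loop in F32.
bool IsNoopConversion(const HloInstruction* inst) {
  return inst->opcode() == HloOpcode::kConvert &&
         inst->shape().element_type() ==
             inst->operand(0)->shape().element_type();
}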
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 2124b302cc..aeafb25ad7 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
array_b.FillUnique(10.0f);
HloInstruction* a = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateFromArray(array_a)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
HloInstruction* b = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateFromArray(array_b)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
@@ -150,10 +150,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)),
+ *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)),
+ *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
@@ -240,12 +240,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), dot);
- EXPECT_TRUE(OutputsBF16(add0));
EXPECT_TRUE(OutputsBF16(add1));
EXPECT_TRUE(OutputsBF16(lhs));
- // rhs is a get-tuple-element, which does not define a buffer, but its shape
- // should also be adjusted accordingly.
- EXPECT_TRUE(OutputsBF16(rhs));
+
+ // add0 and rhs have been eliminated by simplification and DCE.
}
// Tests that a non-fusion computation's root should not be changed.
@@ -734,10 +732,8 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), add2);
- EXPECT_EQ(add2->operand(0), gte0);
- EXPECT_EQ(add2->operand(1), gte1);
- EXPECT_EQ(gte0->shape().element_type(), BF16);
- EXPECT_EQ(gte1->shape().element_type(), BF16);
+ EXPECT_EQ(add2->operand(0), add0);
+ EXPECT_EQ(add2->operand(1), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
EXPECT_EQ(add1->shape().element_type(), BF16);
}
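Aside: the test hunks from here on are dominated by one mechanical migration. A sketch, assuming the factory functions kept their signatures when moving from the Literal class to the LiteralUtil namespace (the hunks swap the include of literal_util.h for literal.h at the same time):

namespace xla {
void LiteralFactoryExamples() {
  // Before this change: Literal::CreateR0<float>(0.0f), etc.
  auto zero = LiteralUtil::CreateR0<float>(0.0f);
  auto vec = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
  auto pair = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
                                      LiteralUtil::CreateR0<int64>(1).get()});
}
}  // namespace xla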
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 6958ee722a..125ade2a11 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
@@ -125,7 +125,7 @@ class BufferAssignmentTest : public HloTestBase {
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
auto value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
return builder.Build();
@@ -142,7 +142,7 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
@@ -167,9 +167,9 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto indexc = builder.AddInstruction(
@@ -290,7 +290,7 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
TEST_F(BufferAssignmentTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -304,9 +304,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
// no buffers assigned, and their consumer has a buffer.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
auto module = CreateNewModule();
@@ -327,7 +327,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) {
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32vec100_, "param0"));
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto tuple = builder.AddInstruction(
@@ -352,7 +352,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
// This computation copies a constant to output.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
auto module = CreateNewModule();
@@ -660,7 +660,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
auto exp2 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
/*shape=*/f32vec10_,
/*operand=*/exp2,
@@ -708,9 +708,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto const3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
@@ -773,11 +773,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
auto const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
r0f32_, pred, const1, true_computation, const2, false_computation));
module->AddEntryComputation(builder.Build());
@@ -1200,8 +1200,9 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
- {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1584,7 +1585,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
auto b = HloComputation::Builder(TestName() + ".cond");
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
condition = module->AddEmbeddedComputation(b.Build());
}
HloComputation* body;
@@ -1647,9 +1648,9 @@ class WhileBufferAssignmentTest : public HloTestBase {
builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto ten = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
return builder.Build();
@@ -1708,7 +1709,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
HloInstruction::CreateParameter(2, data_shape_, "weights1"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
@@ -1851,7 +1852,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto build_cond = [&]() {
auto builder = HloComputation::Builder("cond");
auto const4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1863,7 +1864,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto build_body = [&]() {
auto builder = HloComputation::Builder("body");
auto const9 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(9)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
builder.AddInstruction(
@@ -1875,7 +1876,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto infeed =
builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
auto infeed_data = builder.AddInstruction(
@@ -1891,7 +1892,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
auto cond2 = module->AddEmbeddedComputation(build_cond());
@@ -1953,7 +1954,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
@@ -1997,16 +1998,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param"));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
sub_computation = module->AddEmbeddedComputation(builder.Build(add));
}
auto builder = HloComputation::Builder(TestName());
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto call1 = builder.AddInstruction(
HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
auto call2 = builder.AddInstruction(
@@ -2058,9 +2059,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
@@ -2142,7 +2143,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
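Aside: a second mechanical rename also runs through these tests: token creation. A sketch assuming CreateToken() is the zero-operand replacement for CreateAfterAll({}), producing the token operand that infeed/outfeed/send/recv now take:

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
HloInstruction* AddInfeed(HloComputation::Builder* builder,
                          const Shape& shape) {
  // Before: HloInstruction::CreateAfterAll({}) produced the token.
  HloInstruction* token =
      builder->AddInstruction(HloInstruction::CreateToken());
  return builder->AddInstruction(
      HloInstruction::CreateInfeed(shape, token, ""));
}
}  // namespace xla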
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 7833ebe73b..4a927b5767 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -327,7 +327,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
@@ -439,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 = Literal::MakeTuple(
- {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
- auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
+ auto inner_tuple0 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()});
+ auto inner_tuple1 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
@@ -491,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@@ -503,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element1_shape, tuple_param0, 1));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
@@ -555,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@@ -627,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
HloInstruction* slice = nullptr;
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
@@ -638,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -757,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
@@ -766,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index 1ea7d538cd..cc80b74843 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
return builder.Build();
@@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation::Builder builder(TestName());
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloInstruction* const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
HloInstruction* const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.6f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
kScalarShape, pred, const1, true_computation, const2,
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 482ccc5b67..256d05a73e 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -151,6 +152,14 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
}
return Status::OK();
}));
+ if (did_mutate) {
+ // Run DCE to remove called computations which are now unused. Without
+ // this, if the called computation contained send/recv instructions, the
+ // module group verifier will flag an error: it finds the same channel ID
+ // used for multiple send/recv instructions.
+ TF_RETURN_IF_ERROR(HloDCE().Run(module).status());
+ }
return did_mutate;
}
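Aside: the verifier failure that comment guards against, sketched with a hypothetical cleanup helper (the diff itself inlines the DCE call directly into Run):

#include "tensorflow/compiler/xla/service/hlo_dce.h"

namespace xla {
// After inlining, the callee computation is unreachable but still present in
// the module; if it held a send/recv, its channel ID now also appears on the
// inlined clone, which the module group verifier reports as an error. DCE
// removes the dead computation before verification can trip over it.
Status CleanUpAfterInlining(HloModule* module) {
  return HloDCE().Run(module).status();
}
}  // namespace xla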
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 924348c870..ff968bca29 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// the "one" value.
HloComputation::Builder inner(TestName() + ".inner");
HloInstruction* zero = inner.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f)));
HloInstruction* one = inner.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
TF_ASSERT_OK(zero->AddControlDependencyTo(one));
auto module = CreateNewModule();
HloComputation* inner_computation =
@@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
// little trickier.
HloComputation::Builder just_false(TestName() + ".false");
just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* false_computation =
module->AddEmbeddedComputation(just_false.Build());
@@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder outer(TestName() + ".outer");
HloInstruction* init_value = outer.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
outer.AddInstruction(
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
@@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
HloComputation::Builder just_false(TestName() + ".false");
auto* true_constant = just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<bool>({true})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true})));
auto* false_constant = just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant));
HloComputation* false_computation =
module->AddEmbeddedComputation(just_false.Build());
@@ -147,8 +147,8 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
HloComputation::Builder outfeeder(TestName() + ".outfeeder");
auto value = outfeeder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
- auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto token = outfeeder.AddInstruction(HloInstruction::CreateToken());
outfeeder.AddInstruction(
HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/""));
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 7c1bacff92..d26486fcfe 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index e9ec796121..b7be3ba605 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 68f6ffc6b7..c43a31b167 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "param"));
auto one = true_computation_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
@@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
"param"));
auto forty_two = false_computation_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
@@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
}
auto false_instrn = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
builder.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
@@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
HloComputation* computation = MakeConditional(&module());
auto* true_op = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
TF_ASSERT_OK(
true_op->AddControlDependencyTo(computation->root_instruction()));
@@ -119,11 +119,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
- auto* token =
- true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = true_computation->AddInstruction(HloInstruction::CreateToken());
auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
true_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
@@ -135,8 +134,7 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
- auto* token =
- true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = true_computation->AddInstruction(HloInstruction::CreateToken());
auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv(
ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
@@ -148,8 +146,7 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
auto* conditional = computation->root_instruction();
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* false_computation = conditional->false_computation();
- auto token =
- false_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = false_computation->AddInstruction(HloInstruction::CreateToken());
false_computation->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 52e66b3e77..ab3d846403 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -1092,12 +1092,14 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
-Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
- HloModule* module) {
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering, HloModule* module,
+ const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer) {
MaybeDumpModule("after adding copies to resolve interference", *module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer));
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString());
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index c5573f76f3..e1973db928 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -60,12 +60,6 @@ class CopyInsertion : public HloPassInterface {
// (copies were inserted).
StatusOr<bool> Run(HloModule* module) override;
- // Try to remove as many copies from the module as possible without
- // introducing live range interference. Only copy instructions that are
- // eligible for copy elision are considered for removal.
- Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
- HloModule* module);
-
// The CPU and GPU backend need additional copies added due to deficiencies in
// buffer assignment. Specifically, copies are needed for constants live-out
// of computations, and for values which are live-in and live-out of the same
@@ -83,6 +77,13 @@ class CopyInsertion : public HloPassInterface {
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
};
+// Try to remove as many copies from the module as possible without introducing
+// live range interference. Only copy instructions that are eligible for
+// copy elision are considered for removal.
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering, HloModule* module,
+ const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr);
} // namespace xla
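Aside: a hypothetical call site for the newly free RemoveUnnecessaryCopies (DependencyHloOrdering is chosen here only for illustration; any HloOrdering works, and the predicate defaults to nullptr):

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"

namespace xla {
Status ElideCopies(HloModule* module) {
  // Omitting the third argument uses the backend-neutral buffer-sharing
  // rules (fusion_can_share_buffer == nullptr).
  DependencyHloOrdering ordering(module);
  return RemoveUnnecessaryCopies(ordering, module);
}
}  // namespace xla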
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 105d117cac..cd735256b8 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) {
// be copied before entering the tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
@@ -132,7 +132,7 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
+ LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape());
Layout reversed_layout =
LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major);
@@ -167,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* x = builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
@@ -197,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
// the computation result. Verify that copies are added properly.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -209,7 +209,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
HloInstruction::CreateTuple({constant3, constant2}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
@@ -255,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
// The output of a bitcast is its operand (same buffer), so a bitcast
// constant feeding the result must have a copy added.
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0})));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 42.0})));
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
@@ -370,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
// copy is added.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -380,7 +381,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
HloInstruction::CreateTuple({constant2, constant1}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
HloInstruction* gte =
@@ -413,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const Shape& loop_state_shape) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
auto induction_variable =
@@ -442,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// Update data GTE(1).
@@ -480,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -549,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
@@ -564,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
}
- auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
@@ -598,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
gte0->shape(), HloOpcode::kAdd, gte0, inc));
@@ -608,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// GTE(GTE(loop_state, 1), 0) -> Add
auto gte10 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
- auto update10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, gte10, update10));
@@ -633,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
bool nested = false) {
auto builder = HloComputation::Builder(TestName() + ".While");
auto induction_var_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
if (nested) {
auto inner_init = builder.AddInstruction(
@@ -659,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
auto builder = HloComputation::Builder(TestName() + ".While");
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
&builder);
}
@@ -677,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
@@ -689,7 +694,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
@@ -701,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto data_init =
@@ -714,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto data_init = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
- auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto one_vec = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    // Take a reference to 'data_init' to make it interfere with the while result.
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data_init, one_vec));
@@ -750,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const bool nested =
ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
auto induction_var_init = builder->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto condition = module_->AddEmbeddedComputation(
BuildConditionComputation(loop_state_shape));
auto body = module_->AddEmbeddedComputation(
@@ -1252,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
-
  // Two while loops share the same loop init tuple.
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition1, body1, loop_init));
@@ -1310,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1318,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1375,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1383,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1435,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1443,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
builder.AddInstruction(
@@ -1520,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1575,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -1644,7 +1649,7 @@ std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "loop_state"));
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kNot, constant));
return builder.Build();
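
The test churn above is one mechanical migration: the static Literal::CreateRx(...) factory functions moved to LiteralUtil. A minimal before/after sketch of a call site (the surrounding builder code is unchanged):

    // Before (no longer compiles after this change):
    //   auto c = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0));
    // After: the same factory is provided by LiteralUtil.
    #include "tensorflow/compiler/xla/literal_util.h"

    auto constant = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

The BUILD changes that follow mirror the same split: targets that only use the Literal type now depend on //tensorflow/compiler/xla:literal, while targets that call the factories keep (or add) :literal_util.
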
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 3479240610..c45d914e93 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -37,6 +37,7 @@ cc_library(
srcs = ["cpu_transfer_manager.cc"],
hdrs = ["cpu_transfer_manager.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -72,7 +73,7 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -89,7 +90,6 @@ cc_library(
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
- "//tensorflow/compiler/xla/service:gather_expander",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
@@ -355,7 +355,7 @@ tf_cc_binary(
srcs = ["sample_harness.cc"],
deps = [
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -717,7 +717,7 @@ tf_cc_test(
deps = [
":cpu_layout_assignment",
":target_machine_features_fake",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -809,7 +809,7 @@ tf_cc_test(
":cpu_executable",
":parallel_task_assignment",
":target_machine_features_fake",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -892,7 +892,7 @@ tf_cc_test(
srcs = ["cpu_copy_insertion_test.cc"],
deps = [
":cpu_copy_insertion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 375b017b09..547d4c696d 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in CNHW order.
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
// The kernel dimensions are in OIHW order.
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
ConvolutionDimensionNumbers dnums;
@@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in NHWC order.
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
// The kernel dimensions are in HWIO order.
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount))));
ConvolutionDimensionNumbers dnums;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 55962ba70d..29fa29d33a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
@@ -38,7 +39,7 @@ limitations under the License.
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -66,7 +67,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@@ -297,8 +297,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
- pipeline.AddPass<GatherExpander>();
-
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -607,7 +605,13 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
/*is_top_level_computation=*/true,
&module_sequence.at(entry_computation)));
- string function_name = llvm_ir::AsString(entry_function->getName());
+ string function_name = [&]() {
+ llvm::SmallVector<char, 40> function_name_vector;
+ llvm::Mangler::getNameWithPrefix(
+ function_name_vector, entry_function->getName(), jit->data_layout());
+ return string(function_name_vector.begin(), function_name_vector.end());
+ }();
+
string ir_module_string;
if (embed_ir_in_executable) {
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
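
The RunBackend change above stops assuming that the LLVM IR function name is also the symbol name visible to the JIT, and instead asks llvm::Mangler for the platform-mangled name (for example, Mach-O targets prepend an underscore). A standalone sketch of the same call, assuming a function and the target DataLayout are already in hand:

    #include <string>
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Mangler.h"

    std::string MangledName(const llvm::Function& fn,
                            const llvm::DataLayout& data_layout) {
      llvm::SmallVector<char, 40> mangled;
      // Applies the DataLayout's global prefix (e.g. '_' on Darwin) to the
      // IR-level name, matching what the JIT's symbol table will contain.
      llvm::Mangler::getNameWithPrefix(mangled, fn.getName(), data_layout);
      return std::string(mangled.begin(), mangled.end());
    }
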
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index a05a269417..4db7fa446e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto constant = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
sub_builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 750310c633..991b14f17d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "arg0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
return module->AddEmbeddedComputation(builder.Build());
@@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
auto pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(S32, {5}), idx_choice,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
padding_config));
auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 429fc7b786..3681d12d8d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index b877b29581..156166bf2b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -180,7 +181,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::gtl::ArraySlice<int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
- *literal = std::move(*Literal::CreateFromDimensions(
+ *literal = std::move(*LiteralUtil::CreateFromDimensions(
literal_shape.element_type(), dimensions));
TF_ASSIGN_OR_RETURN(Shape received_shape,
TransferArrayBufferFromOutfeed(
@@ -211,7 +212,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::bit_cast<const int64*>(
tuple_element_shape.dimensions().data()),
tuple_element_shape.dimensions().size());
- auto empty = Literal::CreateFromDimensions(
+ auto empty = LiteralUtil::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
int64 size = GetByteSizeRequirement(tuple_element_shape);
buffer_data.push_back({empty->untyped_data(), size});
@@ -232,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
*elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
}
- *literal = std::move(*Literal::MakeTupleOwned(std::move(elements)));
+ *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
return Status::OK();
}
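
The outfeed path uses the relocated factories in two steps: CreateFromDimensions allocates a literal of the requested shape whose buffer the outfeed machinery then fills, and MakeTupleOwned assembles the received elements into a single tuple literal. A rough sketch of the pattern (element type and dimensions are illustrative):

    // Allocate a destination literal and expose its raw buffer.
    std::unique_ptr<xla::Literal> elem =
        xla::LiteralUtil::CreateFromDimensions(xla::F32, {128, 64});
    void* destination = elem->untyped_data();  // filled by the outfeed

    // After all elements are received, wrap them into one tuple literal.
    std::vector<std::unique_ptr<xla::Literal>> elements;
    elements.push_back(std::move(elem));
    std::unique_ptr<xla::Literal> tuple =
        xla::LiteralUtil::MakeTupleOwned(std::move(elements));
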
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 6dfc666f09..593575c0fd 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -39,13 +39,14 @@ class CpuTransferManager : public GenericTransferManager {
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
- Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
- const void* source) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
private:
+ Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
+ const void* source);
+
// Transfers infeed data to device. InfeedBuffer->Done() must be
// called to clean up the memory allocated for InfeedBuffer.
StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal(
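
Note the header change just above: TransferBufferToInfeed is no longer a public override but a private helper, presumably because this change removes the corresponding virtual from the generic TransferManager interface (an inference from this hunk; the base-class change is not shown here). The resulting shape of the class, condensed:

    class CpuTransferManager : public GenericTransferManager {
     public:
      // Still part of the public TransferManager interface.
      Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                     const LiteralSlice& literal) override;

     private:
      // Now an internal detail of the infeed path; note the dropped 'override'.
      Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
                                    const void* source);
    };
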
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6b9a1d8c01..2ad41374d3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -476,42 +476,111 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
+ llvm::Function* mapped_ir_function =
+ FindOrDie(emitted_functions_, map->to_apply());
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : map->operands()) {
+ const llvm_ir::IrArray& array = GetIrArrayFor(operand);
+ parameter_addresses.push_back(
+ array.EmitArrayElementAddress(index, &ir_builder_));
+ }
+ return EmitElementFunctionCall(mapped_ir_function, map->shape(),
+ parameter_addresses, "map_function");
+}
+
Status IrEmitter::HandleMap(HloInstruction* map) {
- gtl::ArraySlice<HloInstruction*> operands(map->operands());
- HloComputation* function = map->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
-
- return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function](
- const llvm_ir::IrArray::Index& index) {
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : operands) {
- const llvm_ir::IrArray& array = GetIrArrayFor(operand);
- parameter_addresses.push_back(
- array.EmitArrayElementAddress(index, &ir_builder_));
- }
- return EmitElementFunctionCall(mapped_ir_function, map->shape(),
- parameter_addresses, "map_function");
+ return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
});
}
-Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
- auto operand = reduce_window->operand(0);
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
HloComputation* function = reduce_window->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+
+ // We fold inputs into the accumulator and initialize it to
+ // the initial value on the reduce_window.
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
+ "reduce_window_accumulator_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(operand_element_type));
+ ir_builder_.CreateStore(
+ ir_builder_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
+
+ llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), index.size());
+ llvm::Value* in_bounds_condition = nullptr;
+ for (size_t i = 0; i < index.size(); ++i) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ ir_builder_.getInt64(window.dimensions(i).padding_low()));
+
+    // We need to check if 0 <= input_index[i] < bound; otherwise we are in
+    // the padding and can skip the computation. The check is equivalent to
+    // input_index[i] < bound as an *unsigned* comparison, since a negative
+    // value will wrap to a large positive value.
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ input_index[i],
+ ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = index_condition;
+ } else {
+ in_bounds_condition =
+ ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ }
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ llvm_ir::IrArray input_array(GetIrArrayFor(operand));
+ llvm::Value* input_value_address =
+ input_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce_window->shape(),
+ {accumulator_address, input_value_address}, "reducer_function");
+ ir_builder_.CreateStore(result, accumulator_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_address);
+}
+
+Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
- /*instruction=*/*reduce_window, /*operands=*/{operand},
+ /*instruction=*/*reduce_window,
+ /*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32}));
// TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
+ if (window_util::HasDilation(reduce_window->window())) {
return Unimplemented(
"Dilation for ReduceWindow is not implemented on CPU.");
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
@@ -526,73 +595,9 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// This is completely un-optimized and just here to have something
// that works.
return EmitTargetElementLoop(
- reduce_window, [this, reduce_window, operand, window,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // We fold inputs into the accumulator and initialize it to
- // the initial value on the reduce_window.
- PrimitiveType operand_element_type = operand->shape().element_type();
- llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accumulator_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(operand_element_type));
- ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
- reduce_window->operand(1))),
- accumulator_address);
-
- llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"),
- &ir_builder_);
- std::vector<int64> window_size;
- for (const auto& dim : window.dimensions()) {
- window_size.push_back(dim.size());
- }
- const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
- ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- CHECK_EQ(window_index.size(), index.size());
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(),
- index.size());
- llvm::Value* in_bounds_condition = nullptr;
- for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
- input_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- ir_builder_.getInt64(window.dimensions(i).padding_low()));
-
- // We need to check if 0 <= input_index[i] < bound, as
- // otherwise we are in the padding so that we can skip the
- // computation. That is equivalent to input_index[i] < bound
- // as an *unsigned* comparison, since a negative value will
- // wrap to a large positive value.
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
- input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
- operand->shape(), i)));
- if (in_bounds_condition == nullptr) {
- in_bounds_condition = index_condition;
- } else {
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
- }
- }
- CHECK(in_bounds_condition != nullptr);
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value_address =
- input_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce_window->shape(),
- {accumulator_address, input_value_address}, "reducer_function");
- ir_builder_.CreateStore(result, accumulator_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_address);
+ reduce_window, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduceWindow(
+ Cast<HloReduceWindowInstruction>(reduce_window), index);
});
}
@@ -821,17 +826,157 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_machine_features_);
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* lhs = convolution->operand(0);
+ const HloInstruction* rhs = convolution->operand(1);
+ const Window& window = convolution->window();
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+ int num_spatial_dims = dnums.output_spatial_dimensions_size();
+ std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+ }
+ llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+ llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+ // We will accumulate the products into this sum to calculate the output entry
+ // at the given index.
+ PrimitiveType lhs_element_type = lhs->shape().element_type();
+ llvm::Type* lhs_llvm_type =
+ llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+ llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ lhs_llvm_type, "convolution_sum_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(lhs_element_type));
+ llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
+ ir_builder_.CreateStore(constant_zero, sum_address);
+
+ llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
+ std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial[i] =
+ loops
+ .AddLoop(
+ 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+ tensorflow::strings::StrCat("k", i))
+ ->GetIndVarValue();
+ }
+ llvm::Value* input_feature =
+ loops
+ .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+ "iz")
+ ->GetIndVarValue();
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Calculate the spatial index in the input array, taking striding, dilation
+ // and padding into account. An index in the padding will be out of the bounds
+ // of the array.
+ const auto calculate_input_index = [this](llvm::Value* output_index,
+ llvm::Value* kernel_index,
+ const WindowDimension& window_dim) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ output_index, ir_builder_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
+ kernel_index, ir_builder_.getInt64(window_dim.window_dilation()));
+ return ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
+ ir_builder_.getInt64(window_dim.padding_low()));
+ };
+ std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] = calculate_input_index(
+ output_spatial[i], kernel_spatial[i], window.dimensions(i));
+ }
+
+  // We need to check if 0 <= input dim < bound; otherwise we are in the
+  // padding and can skip the computation. The check is equivalent to input
+  // dim < bound as an *unsigned* comparison, since a negative value will wrap
+  // to a large positive value. The input dim is dilated, so we need to dilate
+  // the bound as well to match.
+
+ // Also need to check that the input coordinates are not in one of the
+ // holes created by base dilation.
+ const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+ llvm::Value* remainder = ir_builder_.CreateSRem(
+ input_index, ir_builder_.getInt64(base_dilation));
+ return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
+ };
+
+ llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ llvm::ConstantInt* input_bound =
+ ir_builder_.getInt64(window_util::DilatedBound(
+ lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+ window.dimensions(i).base_dilation()));
+ llvm::Value* dim_in_bound =
+ ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_not_in_hole =
+ not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+ llvm::Value* dim_ok = ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
+ }
+
+ // Now we need to map the dilated base coordinates back to the actual
+ // data indices on the lhs.
+ const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+ return ir_builder_.CreateSDiv(input_index,
+ ir_builder_.getInt64(base_dilation));
+ };
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] =
+ undilate(input_spatial[i], window.dimensions(i).base_dilation());
+ }
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ int num_dims = num_spatial_dims + 2;
+ llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+ }
+ input_index[dnums.input_feature_dimension()] = input_feature;
+ input_index[dnums.input_batch_dimension()] = batch;
+
+ llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
+ llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_index[dnums.kernel_spatial_dimensions(i)] =
+ window.dimensions(i).window_reversal()
+ ? ir_builder_.CreateNSWSub(
+ ir_builder_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
+ : kernel_spatial[i];
+ }
+
+ kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
+ kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+ llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
+ llvm::Value* product = ir_builder_.CreateFMul(
+ input_array.EmitReadArrayElement(input_index, &ir_builder_),
+ kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
+ llvm::Value* sum =
+ ir_builder_.CreateFAdd(ir_builder_.CreateLoad(sum_address), product);
+ ir_builder_.CreateStore(sum, sum_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(sum_address);
+}
+
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
- const auto& window = convolution->window();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F16, F32, C64}));
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
-
   // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
// different data layouts.
if (PotentiallyImplementedAsEigenConvolution(*convolution,
@@ -988,150 +1133,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// See the description of convolution in the XLA documentation for the pseudo
// code for convolution.
return EmitTargetElementLoop(
- convolution, [this, convolution, lhs, rhs, window,
- dnums](const llvm_ir::IrArray::Index& index) {
- int num_spatial_dims = dnums.output_spatial_dimensions_size();
- std::vector<llvm::Value*> output_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
- }
- llvm::Value* output_feature = index[dnums.output_feature_dimension()];
- llvm::Value* batch = index[dnums.output_batch_dimension()];
-
- // We will accumulate the products into this sum to calculate
- // the output entry at the given index.
- PrimitiveType lhs_element_type = lhs->shape().element_type();
- llvm::Type* lhs_llvm_type =
- llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
- llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
- lhs_llvm_type, "convolution_sum_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(lhs_element_type));
- llvm::Value* constant_zero =
- llvm::Constant::getNullValue(lhs_llvm_type);
- ir_builder_.CreateStore(constant_zero, sum_address);
-
- llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
- std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_spatial[i] =
- loops
- .AddLoop(0,
- rhs->shape().dimensions(
- dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
- ->GetIndVarValue();
- }
- llvm::Value* input_feature =
- loops
- .AddLoop(
- 0, lhs->shape().dimensions(dnums.input_feature_dimension()),
- "iz")
- ->GetIndVarValue();
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Calculate the spatial index in the input array, taking striding,
- // dilation and padding into account. An index in the padding will be
- // out of the bounds of the array.
- const auto calculate_input_index =
- [this](llvm::Value* output_index, llvm::Value* kernel_index,
- const WindowDimension& window_dim) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- output_index, ir_builder_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
- kernel_index,
- ir_builder_.getInt64(window_dim.window_dilation()));
- return ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
- ir_builder_.getInt64(window_dim.padding_low()));
- };
- std::vector<llvm::Value*> input_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] = calculate_input_index(
- output_spatial[i], kernel_spatial[i], window.dimensions(i));
- }
-
- // We need to check if 0 <= input dim < bound, as otherwise we are in
- // the padding so that we can skip the computation. That is equivalent
- // to input dim < bound as an *unsigned* comparison, since a negative
- // value will wrap to a large positive value. The input dim is dilated,
- // so we need to dilate the bound as well to match.
-
- // Also need to check that the input coordinates are not in one of the
- // holes created by base dilation.
- const auto not_in_hole = [&](llvm::Value* input_index,
- int64 base_dilation) {
- llvm::Value* remainder = ir_builder_.CreateSRem(
- input_index, ir_builder_.getInt64(base_dilation));
- return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
- };
-
- llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
- for (int i = 0; i < num_spatial_dims; ++i) {
- llvm::ConstantInt* input_bound =
- ir_builder_.getInt64(window_util::DilatedBound(
- lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
- window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound =
- ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
- llvm::Value* dim_not_in_hole = not_in_hole(
- input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok =
- ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
- }
-
- // Now we need to map the dilated base coordinates back to the actual
- // data indices on the lhs.
- const auto undilate = [&](llvm::Value* input_index,
- int64 base_dilation) {
- return ir_builder_.CreateSDiv(input_index,
- ir_builder_.getInt64(base_dilation));
- };
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] =
- undilate(input_spatial[i], window.dimensions(i).base_dilation());
- }
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- int num_dims = num_spatial_dims + 2;
- llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
- }
- input_index[dnums.input_feature_dimension()] = input_feature;
- input_index[dnums.input_batch_dimension()] = batch;
-
- llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
- llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(),
- num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_index[dnums.kernel_spatial_dimensions(i)] =
- window.dimensions(i).window_reversal()
- ? ir_builder_.CreateNSWSub(
- ir_builder_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
- : kernel_spatial[i];
- }
-
- kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
- kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
- llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
- llvm::Value* product = ir_builder_.CreateFMul(
- input_array.EmitReadArrayElement(input_index, &ir_builder_),
- kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
- llvm::Value* sum = ir_builder_.CreateFAdd(
- ir_builder_.CreateLoad(sum_address), product);
- ir_builder_.CreateStore(sum, sum_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(sum_address);
+ convolution, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForConvolution(
+ Cast<HloConvolutionInstruction>(convolution), index);
});
}
@@ -1768,6 +1772,64 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
return true;
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* arg = reduce->mutable_operand(0);
+ const HloInstruction* init_value = reduce->mutable_operand(1);
+ gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ HloComputation* function = reduce->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+
+ // Initialize an accumulator with init_value.
+ PrimitiveType accumulator_type = reduce->shape().element_type();
+ llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
+ &ir_builder_, MinimumAlignmentForPrimitiveType(accumulator_type));
+ llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
+ llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
+ ir_builder_.CreateStore(load_init_value, accumulator_addr);
+
+ // The enclosing loops go over all the target elements. Now we have to compute
+ // the actual target element. For this, we build a new loop nest to iterate
+ // over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
+ // are placed for each dimension in dimensions, and all the rest are nullptrs.
+ llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
+ const llvm_ir::IrArray::Index reduced_dims_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Build a full index for the input argument, using reduced_dims_index as the
+ // base. In reduced_dims_index only the reduction dimensions are filled in. We
+ // fill in the rest of the dimensions with induction Value*s taken from
+ // 'index' which iterates over the target array. See the high-level
+ // description in the XLA documentation for details.
+ llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
+ llvm_ir::IrArray::Index input_index = reduced_dims_index;
+ llvm_ir::IrArray::Index::const_iterator it = index.begin();
+
+ for (size_t i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+
+ // Apply the reduction function to the loaded value.
+ llvm::Value* input_address =
+ arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce->shape(), {accumulator_addr, input_address},
+ "reduce_function");
+ ir_builder_.CreateStore(result, accumulator_addr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_addr);
+}
+
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
@@ -1789,61 +1851,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
- return EmitTargetElementLoop(
- reduce, [this, reduce, arg, init_value, dimensions,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // Initialize an accumulator with init_value.
- PrimitiveType accumulator_type = reduce->shape().element_type();
- llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
- "accumulator", &ir_builder_,
- MinimumAlignmentForPrimitiveType(accumulator_type));
- llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
- ir_builder_.CreateStore(load_init_value, accumulator_addr);
-
- // The enclosing loops go over all the target elements. Now we have to
- // compute the actual target element. For this, we build a new loop nest
- // to iterate over all the reduction dimensions in the argument.
- // AddLoopsForShapeOnDimensions will return an Index where induction
- // Value*s are placed for each dimension in dimensions, and all the rest
- // are nullptrs.
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
- const llvm_ir::IrArray::Index reduced_dims_index =
- loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
- "reduction_dim");
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Build a full index for the input argument, using reduced_dims_index
- // as the base. In reduced_dims_index only the reduction dimensions are
- // filled in. We fill in the rest of the dimensions with induction
- // Value*s taken from 'index' which iterates over the target array.
- // See the high-level description in the XLA documentation for details.
- llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
- llvm_ir::IrArray::Index input_index = reduced_dims_index;
- llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
- for (size_t i = 0; i < input_index.size(); ++i) {
- if (input_index[i] == nullptr) {
- input_index[i] = *it++;
- }
- }
- CHECK(index.end() == it);
-
- // Apply the reduction function to the loaded value.
- llvm::Value* input_address =
- arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce->shape(),
- {accumulator_addr, input_address}, "reduce_function");
- ir_builder_.CreateStore(result, accumulator_addr);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_addr);
- });
+ return EmitTargetElementLoop(reduce,
+ [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduce(
+ Cast<HloReduceInstruction>(reduce), index);
+ });
}
Status IrEmitter::HandleSend(HloInstruction* send) {
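
All four ir_emitter.cc hunks above apply the same refactor: the long capture-heavy lambda passed to EmitTargetElementLoop becomes a named EmitTargetElementLoopBodyFor* member function that takes the instruction (downcast via Cast<>) plus the element index, and the handler shrinks to a thin forwarding lambda capturing by reference. A schematic of the shape of that change ("Foo" is a placeholder, not a real HLO):

    // Before: handler logic lives inline in the lambda.
    Status IrEmitter::HandleFoo(HloInstruction* foo) {
      return EmitTargetElementLoop(
          foo, [this, foo /*, ...captures*/](const llvm_ir::IrArray::Index& i) {
            // ~100 lines of IR emission here
          });
    }

    // After: the body is a named member; the lambda just forwards.
    Status IrEmitter::HandleFoo(HloInstruction* foo) {
      return EmitTargetElementLoop(foo, [&](const llvm_ir::IrArray::Index& i) {
        return EmitTargetElementLoopBodyForFoo(Cast<HloFooInstruction>(foo), i);
      });
    }
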
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3089f6451e..419f19c24d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -514,6 +515,17 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Returns the number of bytes within the shape.
int64 ByteSizeOf(const Shape& shape) const;
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index);
+
enum class XfeedKind {
kInfeed,
kOutfeed,
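
One detail worth calling out from the loop bodies declared above: the emitted in-bounds checks fold 0 <= i && i < bound into a single unsigned comparison (CreateICmpULT), since a negative signed value reinterpreted as unsigned wraps to a huge number. The scalar equivalent:

    #include <cstdint>

    // True iff 0 <= index < bound, using one unsigned comparison -- the same
    // trick as the CreateICmpULT the IR emitter generates.
    bool InBounds(int64_t index, int64_t bound) {
      return static_cast<uint64_t>(index) < static_cast<uint64_t>(bound);
    }
    // InBounds(-1, 8) == false: -1 wraps to 0xFFFFFFFFFFFFFFFF.
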
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 7e792a82b8..d9e8dcaed9 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -38,12 +38,13 @@ int main(int argc, char** argv) {
// Transfer parameters.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
client->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal = xla::Literal::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 66ae5ef0f6..b4c33e2f6c 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -40,7 +40,7 @@ tf_cc_test(
name = "cpu_fusion_test",
srcs = ["cpu_fusion_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -82,7 +82,7 @@ tf_cc_test(
name = "cpu_noalias_test",
srcs = ["cpu_noalias_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -128,7 +128,7 @@ tf_cc_test(
name = "cpu_infeed_test",
srcs = ["cpu_infeed_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
index 7c8d07a10b..77b3a0301f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
@@ -22,7 +22,7 @@ namespace xla {
namespace cpu {
// Tests that verify IR emitted by the CPU backend is as expected.
-class CpuCodegenTest : public LLVMIRGenTestBase {};
+class CpuCodegenTest : public LlvmIrGenTestBase {};
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
index 1d4bf483ae..00a7aa2ad2 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
@@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest {
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D(backing_array)));
+ LiteralUtil::CreateR2FromArray2D(backing_array)));
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 783b2820e9..d98856fdbf 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase {
TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
- auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
- auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
Shape vshape = input_literal1->shape();
auto input1 = builder.AddInstruction(
@@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
auto input = builder.AddInstruction(
@@ -99,7 +99,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
vshape,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
{}));
builder.AddInstruction(
HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor));
@@ -134,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
// middle.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
auto input = builder.AddInstruction(
@@ -166,7 +166,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
/*init_value=*/
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{1}, add_f32));
auto exp = builder.AddInstruction(
@@ -176,7 +176,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
cshape,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
{}));
builder.AddInstruction(
HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
@@ -231,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// operand vectors. Test for this problem by counting the number of nodes in
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
Shape vshape = input_literal->shape();
auto constant = builder.AddInstruction(
@@ -292,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
// computation. The duplication is caused by the other use of exp2 in the
// tuple.
auto builder = HloComputation::Builder(TestName());
- auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
- auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
Shape shape = constant->shape();
auto exp1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index ea7e479d66..0d45918d09 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*Literal::CreateR0<bool>(true));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0minor));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0minor));
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0major));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*Literal::CreateR4(
+ TestInfeedRoundTrip(*LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
- *Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(false).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*Literal::MakeTuple({}));
+ TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
});
// Send 5 Infeed data of shape F32[3].
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
@@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 3b6b0ed740..ccb61740f6 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
@@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index 32b5c5d35f..e727ba49cb 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/defuser.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
@@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) {
auto div = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto add2 = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
@@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) {
auto div = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto add2 = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 52aa53dcee..51f16bdc94 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <type_traits>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index ecd97a8796..0686ca74af 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 21c6f7d358..004a80d19d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1565,19 +1565,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
+ // to officially document different behavior.
start_index_value =
ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
- llvm::Value* operand_dim_size =
- index_typed_const(input_hlo->shape().dimensions(i));
- llvm::Value* output_dim_size =
- index_typed_const(hlo->shape().dimensions(i));
+ int64 largest_valid_start_index =
+ input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
+ CHECK_GE(largest_valid_start_index, 0);
+ bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
start_index_value = EmitIntegralMin(
- ir_builder_->CreateSub(operand_dim_size, output_dim_size),
- EmitIntegralMax(index_typed_const(0), start_index_value,
- /*is_signed=*/true),
- /*is_signed=*/true);
+ index_typed_const(largest_valid_start_index),
+ EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
+ is_signed);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
@@ -1610,19 +1609,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
llvm::Type* index_type = index.GetType();
// This is the index into `operand` that holds the element we want to
- // generate. This index "unsafe" as in the components in here may be
- // out of bounds.
- IrArray::Index unsafe_operand_index(index_type);
-
- // First copy in the window indices to unsafe_operand_index.
- for (int64 i = 0, e = operand_shape.dimensions_size(),
- unsafe_operand_index_dim = 0;
+ // generate.
+ IrArray::Index operand_index(index_type);
+
+ // First copy in the window indices to operand_index. Also collect a mapping
+ // from operand dimension to output window dimension. Elided window dimensions
+ // map to -1.
+ std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
+ for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
- unsafe_operand_index.push_back(index.GetConstantWithIndexType(0));
+ operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- unsafe_operand_index.push_back(
- index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
+ int64 output_window_dim =
+ dim_numbers.output_window_dims(operand_index_dim++);
+ operand_to_output_dim[i] = output_window_dim;
+ operand_index.push_back(index[output_window_dim]);
}
}
@@ -1641,20 +1643,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
}
}
- auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
- int64 dim) {
+ auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
ir_builder_->CreateSExtOrTrunc(index_component, index_type);
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
- ir_builder_->CreateAdd(
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
- gather_dim_component_extended);
+ int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 output_dim = operand_to_output_dim[operand_dim];
+ // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
+ // This means we set the iteration index to 0, so for the purpose of the
+ // following calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
+ int64 largest_valid_start_index =
+ operand_shape.dimensions(operand_dim) - output_dim_size;
+ CHECK_GE(largest_valid_start_index, 0);
+
+ // Clamp the gather index so that the gather region fits in the operand.
+ // gather_dim_component_extended_inbound =
+ // clamp(gather_dim_component_extended, 0, largest_valid_start_index);
+
+ // TODO(b/111078873): This is implementation defined behavior.
+ bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
+ auto gather_dim_component_extended_inbound = EmitIntegralMin(
+ index.GetConstantWithIndexType(largest_valid_start_index),
+ EmitIntegralMax(index.GetConstantWithIndexType(0),
+ gather_dim_component_extended, is_signed),
+ is_signed);
+
+ operand_index[operand_dim] = ir_builder_->CreateAdd(
+ operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, 0);
+ add_to_operand_index(gather_dim_component, 0);
} else {
int64 index_vector_size =
indices_shape.dimensions(dim_numbers.index_vector_dim());
@@ -1663,18 +1685,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
index.GetConstantWithIndexType(i);
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, i);
+ add_to_operand_index(gather_dim_component, i);
}
}
-
- IrArray::Index safe_operand_index(index_type);
- for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
- safe_operand_index.push_back(ir_builder_->CreateURem(
- unsafe_operand_index[i],
- index.GetConstantWithIndexType(operand_shape.dimensions(i))));
- }
-
- return operand_generator(safe_operand_index);
+ return operand_generator(operand_index);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
@@ -1706,19 +1720,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
+ // to officially document different behavior.
start_index_value =
ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
- llvm::Value* input_dim_size =
- index_typed_const(input_hlo->shape().dimensions(i));
llvm::Value* update_dim_size =
index_typed_const(update_hlo->shape().dimensions(i));
+ int64 largest_valid_start_index =
+ input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
+ CHECK_GE(largest_valid_start_index, 0);
- start_index_value =
- EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size),
- EmitIntegralMax(index_typed_const(0), start_index_value,
- /*is_signed=*/true),
- /*is_signed=*/true);
+ bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
+ start_index_value = EmitIntegralMin(
+ index_typed_const(largest_valid_start_index),
+ EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
+ is_signed);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
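
[Annotation] Dynamic-slice, dynamic-update-slice, and gather now share one out-of-bounds policy. Instead of wrapping gather indices with a URem and clamping slice starts against IR-computed dimension sizes, the emitter clamps every start index into [0, operand_dim - slice_dim], computes that bound as a host-side constant with a CHECK that it is non-negative, and chooses signed or unsigned min/max from the index element type rather than hard-coding signed. A host-side model of the emitted logic (illustrative; the real code emits this as LLVM IR via EmitIntegralMin/EmitIntegralMax):

    #include <algorithm>
    #include <cstdint>

    // Model of the clamp generated per start index i. slice_dim is the size
    // of the output/update/window in that dimension; for gather, an elided
    // window dimension contributes slice_dim == 1.
    int64_t ClampStartIndex(int64_t start_index, int64_t operand_dim,
                            int64_t slice_dim) {
      const int64_t largest_valid_start = operand_dim - slice_dim;  // CHECK'd >= 0
      return std::min(largest_valid_start, std::max<int64_t>(0, start_index));
    }

With every index component clamped up front, the gather path no longer needs the unsafe/safe index split, so the trailing URem loop is deleted.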
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 8980d43033..addb016b04 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -57,8 +57,8 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});
+ std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {lhs.get(), rhs.get()});
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index fd75847d0c..7cf2746947 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -82,18 +82,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
StatusOr<ScopedShapedBuffer> return_value =
ExecuteOnStream(run_options, arguments, profile_ptr.get());
- if (!return_value.status().ok()) {
- if (profile != nullptr) {
- // Ensure the ThenStartTimer call has completed before we destroy timer.
- // We already have a failure status to return, so just log this if it
- // fails.
- Status status = stream->BlockHostUntilDone();
- if (!status.ok()) {
- LOG(ERROR) << "Failed to BlockHostUntilDone: " << status;
- }
- }
- return return_value.status();
- }
+ TF_RETURN_IF_ERROR(return_value.status());
if (profile != nullptr) {
VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
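
[Annotation] The hand-rolled failure path collapses into TF_RETURN_IF_ERROR, which expands to roughly the following (illustrative; the real macro uses a uniquely named temporary):

    #define TF_RETURN_IF_ERROR(expr)                  \
      do {                                            \
        const ::tensorflow::Status _status = (expr);  \
        if (!_status.ok()) return _status;            \
      } while (0)

Note the behavior change: the old code also called stream->BlockHostUntilDone() before returning a failure when profiling was enabled; the simplified version returns the error immediately.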
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index d3854b40de..8f6608241e 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
return builder.Build();
@@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
HloInstruction* false_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kEq, param0, false_constant));
@@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
{
HloComputation::Builder builder(TestName() + ".entry");
HloInstruction* false_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateWhile(
ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
false_constant));
@@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
// computation in the true and false branch.
HloComputation::Builder builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
builder.AddInstruction(HloInstruction::CreateConditional(
kScalarShape, pred, constant1, sub_computation, constant2,
sub_computation));
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 7cd2c9c136..e3a42d0d06 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <utility>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -113,7 +114,7 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
const Shape& index_shape = index_vector->shape();
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateFromDimensions(index_shape.element_type(), {1})));
+ LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
// We extract out individual components from the smaller index and concatenate
// them (interspersing zeros as needed) into the larger index.
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 85e28a0dfe..e314a469f0 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -158,16 +158,10 @@ Status GenericTransferManager::TransferLiteralToInfeed(
return Unimplemented("Generic transfer to Infeed");
}
-Status GenericTransferManager::TransferBufferToInfeed(
- se::StreamExecutor* executor, int64 size, const void* source) {
- return Unimplemented("Generic transfer to Infeed");
-}
-
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
Literal* literal) {
- return Unimplemented(
- "Outfeed is not supported on this platform (b/30467474)");
+ return Unimplemented("Generic transfer from Outfeed");
}
Status GenericTransferManager::ResetDevices(
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index d216fe7d29..3cd002c1bf 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -61,9 +61,6 @@ class GenericTransferManager : public TransferManager {
int64 GetByteSizeRequirement(const Shape& shape) const override;
protected:
- Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
- const void* source) override;
-
Status WriteSingleTupleIndexTable(
se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d90b0fb57d..59172e53d3 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -150,7 +150,7 @@ cc_library(
":parallel_loop_emitter",
":partition_assignment",
":while_transformer",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -165,6 +165,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
+ "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
@@ -199,7 +200,7 @@ cc_library(
srcs = ["elemental_ir_emitter.cc"],
hdrs = ["elemental_ir_emitter.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -265,6 +266,7 @@ cc_library(
"infeed_thunk.cc",
"kernel_thunk.cc",
"memset_thunk.cc",
+ "outfeed_thunk.cc",
"sequential_thunk.cc",
"thunk_schedule.cc",
"tuple_thunk.cc",
@@ -282,6 +284,7 @@ cc_library(
"infeed_thunk.h",
"kernel_thunk.h",
"memset_thunk.h",
+ "outfeed_thunk.h",
"sequential_thunk.h",
"thunk.h",
"thunk_schedule.h",
@@ -289,15 +292,16 @@ cc_library(
"while_thunk.h",
],
deps = [
- ":backend_configs",
":buffer_allocations",
":cudnn_convolution_runner",
":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
+ ":outfeed_manager",
":partition_assignment",
":stream_assignment",
"//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -351,6 +355,7 @@ cc_library(
":cudnn_convolution_runner",
":gpu_executable",
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -382,7 +387,7 @@ cc_library(
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -517,6 +522,7 @@ cc_library(
hdrs = ["pad_insertion.h"],
deps = [
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -533,7 +539,10 @@ cc_library(
hdrs = ["gpu_transfer_manager.h"],
deps = [
":gpu_compiler",
+ ":outfeed_manager",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -624,6 +633,7 @@ cc_library(
hdrs = ["cudnn_batchnorm_rewriter.h"],
deps = [
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -631,18 +641,39 @@ cc_library(
)
cc_library(
+ name = "xfeed_queue",
+ hdrs = ["xfeed_queue.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
+cc_library(
name = "infeed_manager",
srcs = ["infeed_manager.cc"],
hdrs = ["infeed_manager.h"],
deps = [
+ ":xfeed_queue",
+ "//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library(
+ name = "outfeed_manager",
+ srcs = ["outfeed_manager.cc"],
+ hdrs = ["outfeed_manager.h"],
+ deps = [
+ ":xfeed_queue",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "gpu_layout_assignment",
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
@@ -716,7 +747,7 @@ cc_library(
srcs = ["while_transformer.cc"],
hdrs = ["while_transformer.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 5e4fe1dd39..5780e0af40 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -33,8 +33,11 @@ ConditionalThunk::ConditionalThunk(
predicate_buffer_index_(predicate_buffer_index),
true_operand_buffer_index_(true_operand_buffer_index),
false_operand_buffer_index_(false_operand_buffer_index),
- true_thunk_(std::move(true_thunk_sequence), hlo),
- false_thunk_(std::move(false_thunk_sequence), hlo) {}
+ // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_
+ // constructors because these SequentialThunks are logically "part of"
+ // this ConditionalThunk, and shouldn't be profiled separately from it.
+ true_thunk_(std::move(true_thunk_sequence), nullptr),
+ false_thunk_(std::move(false_thunk_sequence), nullptr) {}
Status ConditionalThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
index c77e3c81c9..6028950652 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -128,8 +131,8 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
computation_->AddInstruction(HloInstruction::CreateBroadcast(
inverse_stddev->shape(),
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-2))),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(-2))),
{}))));
HloInstruction* variance =
computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -169,11 +172,12 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
// The cudnn libcall expects its input to be rsqrt(variance + epsilon), but
// the batchnorm HLO takes plain variance as input. Fix it up.
@@ -189,7 +193,7 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
computation_->AddInstruction(HloInstruction::CreateBroadcast(
var_plus_epsilon->shape(),
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-.5))),
+ LiteralUtil::CreateR0<float>(-.5))),
{}))));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 3dc98c4c93..5a63e65208 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -80,8 +81,7 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
const ConvolutionDimensionNumbers& dnums,
se::StreamExecutor* stream_exec) {
// Skip this check for cudnn7 and newer.
- auto version =
- stream_exec->AsDnn()->GetVersion();
+ auto version = stream_exec->AsDnn()->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
@@ -338,8 +338,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
computation->AddInstruction(HloInstruction::CreateTuple(
{computation->AddInstruction(HloInstruction::CreateGetTupleElement(
new_call_shape.tuple_shapes(0), new_call, 0)),
- computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<uint8>({})))}));
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<uint8>({})))}));
TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
return true;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index f9dccd287d..905b5ee876 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 27d2c3e491..e594cec2f8 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 4fdc55909a..b3a3c5dcb4 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -28,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
+ // constructor because this SequentialThunk is logically "part of"
+ // this ForThunk, and shouldn't be profiled separately from it.
+ std::move(*body_thunk_sequence), nullptr)) {}
Status ForThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
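
[Annotation] ConditionalThunk (above) and ForThunk here adopt the same convention: nested SequentialThunks get a null HloInstruction* so that only the enclosing thunk is profiled. This pairs with the double-measurement CHECK added to hlo_execution_profiler.cc further down. A sketch of how a null HLO is presumably handled by the RAII wrapper (assumed shape of ScopedInstructionProfiler, which is not shown in this excerpt):

    // With hlo == nullptr the wrapper is a no-op, so the nested sequence is
    // timed only as part of its parent thunk.
    class ScopedInstructionProfiler {
     public:
      ScopedInstructionProfiler(HloExecutionProfiler* profiler,
                                const HloInstruction* hlo)
          : profiler_(profiler), hlo_(hlo) {
        if (hlo_ != nullptr) {
          profiler_->StartHloInstruction();
        }
      }
      ~ScopedInstructionProfiler() {
        if (hlo_ != nullptr) {
          profiler_->FinishHloInstruction(hlo_);
        }
      }

     private:
      HloExecutionProfiler* profiler_;
      const HloInstruction* hlo_;
    };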
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index decfc40daf..e1da8d940c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -552,8 +552,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
&ir_emitter_context);
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
- TF_RETURN_IF_ERROR(
- entry_computation->root_instruction()->Accept(&ir_emitter));
+ TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
}
if (user_pre_optimization_hook_) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index e48165c142..95f78ae293 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -132,10 +132,10 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
HloInstruction::CreateParameter(4, aux_shape, "variance"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
shape,
@@ -201,10 +201,10 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
HloInstruction::CreateParameter(2, offset_scale_shape, "offset"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
batchnorm_shape, {operand, scale, offset, epsilon, feature_index},
@@ -278,10 +278,10 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
HloInstruction::CreateParameter(4, shape, "var"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm =
builder.AddInstruction(HloInstruction::CreateCustomCall(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 5343497c03..1446401b19 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include <vector>
#include "llvm/IR/DataLayout.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -34,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
+namespace gpu {
// TODO(b/30467474) Once GPU infeed implementation settles, consider
// folding back the cpu and gpu infeed implementations into a generic
@@ -50,48 +53,28 @@ Status GpuTransferManager::TransferLiteralToInfeed(
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
- if (!ShapeUtil::IsTuple(shape)) {
- int64 size = GetByteSizeRequirement(shape);
- return TransferBufferToInfeed(executor, size, literal.untyped_data());
- }
-
// For a tuple, we transfer each of its elements to the device and
// enqueue the resulting destination device addresses with the
// infeed manager.
- std::vector<gpu::InfeedBuffer*> buffers;
- auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
- for (gpu::InfeedBuffer* b : buffers) {
- b->Done();
- }
- });
+ ShapeTree<InfeedBuffer> buffer_tree(shape);
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
shape, [&](const Shape& literal_subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(literal_subshape)) {
int64 tuple_element_size = GetByteSizeRequirement(literal_subshape);
TF_ASSIGN_OR_RETURN(
- gpu::InfeedBuffer * buffer,
+ *buffer_tree.mutable_element(index),
TransferBufferToInfeedInternal(executor, tuple_element_size,
literal.untyped_data(index)));
- buffers.push_back(buffer);
}
return Status::OK();
}));
- cleanup.release();
- return EnqueueBuffersToInfeed(executor, buffers);
-}
-
-Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
- int64 size,
- const void* source) {
- TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer,
- TransferBufferToInfeedInternal(executor, size, source));
- return EnqueueBuffersToInfeed(executor, {buffer});
+ return EnqueueBuffersToInfeed(executor, std::move(buffer_tree));
}
Status GpuTransferManager::EnqueueBuffersToInfeed(
- se::StreamExecutor* executor, std::vector<gpu::InfeedBuffer*> buffers) {
+ se::StreamExecutor* executor, ShapeTree<InfeedBuffer> buffers) {
gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager();
se::Stream* stream = infeed_manager->GetStream(executor);
@@ -101,21 +84,18 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
// possible.
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
- for (gpu::InfeedBuffer* b : buffers) {
- b->Done();
- }
return InternalError("Failed to complete data transfer on stream %p: %s",
stream, block_status.error_message().c_str());
}
- infeed_manager->EnqueueBuffers(buffers);
+ infeed_manager->EnqueueDestination(std::move(buffers));
VLOG(2) << "Infeed data transferred";
return Status::OK();
}
-StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
+StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
se::StreamExecutor* executor, int64 size, const void* source) {
if (size > std::numeric_limits<int32>::max()) {
return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
@@ -131,18 +111,76 @@ StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
return InternalError("Failed to obtain a stream");
}
- gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size);
- stream->ThenMemcpy(buffer->device_memory(), source, size);
+ InfeedBuffer buffer(executor, size);
+ stream->ThenMemcpy(buffer.device_memory(), source, size);
VLOG(2) << "Queued infeed data on stream " << stream;
- return buffer;
+ return std::move(buffer);
+}
+
+static std::unique_ptr<Literal> ShapeTreeToLiteral(
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree) {
+ // This is a struct instead of a lambda for std::function-free recursion.
+ struct Helper {
+ static std::unique_ptr<Literal> helper(
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree,
+ ShapeIndex* index) {
+ const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index);
+ if (ShapeUtil::IsArray(shape)) {
+ return (*shape_tree->mutable_element(*index))->WaitUntilAvailable();
+ }
+
+ CHECK(ShapeUtil::IsTuple(shape))
+ << ShapeUtil::HumanStringWithLayout(shape);
+ const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape);
+ index->push_back(0);
+ std::vector<std::unique_ptr<Literal>> tuple_operands;
+ for (int64 i = 0; i < tuple_element_count; ++i) {
+ index->back() = i;
+ tuple_operands.push_back(helper(shape_tree, index));
+ }
+ index->pop_back();
+ return LiteralUtil::MakeTupleOwned(std::move(tuple_operands));
+ }
+ };
+ ShapeIndex index;
+ return Helper::helper(shape_tree, &index);
+}
+
+Status GpuTransferManager::TransferLiteralFromOutfeed(
+ se::StreamExecutor* /*executor*/, const Shape& literal_shape,
+ Literal* literal) {
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
+ &literal_shape);
+
+ // First create a tree of literal buffers that the device can write to.
+ outfeed_buffers.ForEachMutableElement(
+ [&](const ShapeIndex& index,
+ std::unique_ptr<gpu::OutfeedBuffer>* buffer) {
+ const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index);
+ // Do not transfer tuple index buffers.
+ if (ShapeUtil::IsTuple(shape)) {
+ return;
+ }
+ *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ });
+
+ // Give the tree of buffers to the outfeed manager. The device will fill it
+ // Give the tree of buffers to the outfeed manager. The device will fill it
+ // while we're waiting for it below.
+ gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager();
+ outfeed_manager->EnqueueDestination(&outfeed_buffers);
+
+ // Now turn the tree of buffers back into a literal.
+ *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers));
+ return Status::OK();
}
+} // namespace gpu
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
- return xla::MakeUnique<xla::GpuTransferManager>();
+ return xla::MakeUnique<xla::gpu::GpuTransferManager>();
}
static bool InitModule() {
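
[Annotation] The new outfeed path is a one-shot rendezvous per leaf buffer: TransferLiteralFromOutfeed builds a ShapeTree of empty OutfeedBuffers, hands it to the OutfeedManager, and then blocks in ShapeTreeToLiteral, whose recursion reassembles the leaf literals into (possibly nested) tuples via LiteralUtil::MakeTupleOwned. OutfeedBuffer itself is not shown in this excerpt; the interface the code above relies on is assumed to look roughly like this, based on the WaitUntilAvailable() call site:

    // Assumed shape of gpu::OutfeedBuffer: filled by the outfeed thunk on the
    // device stream, then handed to the host thread blocked below.
    class OutfeedBuffer {
     public:
      explicit OutfeedBuffer(int64 length) : length_(length) {}

      // Host side: blocks until the device signals the transfer is complete.
      std::unique_ptr<Literal> WaitUntilAvailable() {
        done_.WaitForNotification();
        return std::move(destination_);
      }
      int64 length() const { return length_; }

     private:
      std::unique_ptr<Literal> destination_;
      tensorflow::Notification done_;
      const int64 length_;
    };

Note also that GpuTransferManager moves into namespace xla::gpu, so the registration shim below is updated to xla::gpu::GpuTransferManager.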
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 09f8227f50..8122c9d8c3 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
+namespace gpu {
// An implementation of the XLA GenericTransferManager that
// handles GPU-specific infeed.
@@ -38,23 +40,25 @@ class GpuTransferManager : public GenericTransferManager {
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
- Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
- const void* source) override;
+ Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+ const Shape& literal_shape,
+ Literal* literal) override;
private:
// Initiates the infeed data transfers. The returned InfeedBuffer owns its
// device memory and releases it when the buffer is destroyed.
- StatusOr<gpu::InfeedBuffer*> TransferBufferToInfeedInternal(
+ StatusOr<InfeedBuffer> TransferBufferToInfeedInternal(
se::StreamExecutor* executor, int64 size, const void* source);
// Enqueues infeed data buffers with the infeed manager after their
// transfer completes.
Status EnqueueBuffersToInfeed(se::StreamExecutor* executor,
- std::vector<gpu::InfeedBuffer*> buffers);
+ ShapeTree<InfeedBuffer> buffers);
TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager);
};
+} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index 3e96beb575..19420e590d 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <stack>
+#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -99,6 +100,7 @@ void HloExecutionProfiler::StartHloInstruction() {
void HloExecutionProfiler::FinishHloInstruction(
const HloInstruction* hlo_instruction) {
if (do_profile_) {
+ hlo_instructions_.erase(hlo_instruction);
profile_->SetCyclesTakenBy(
hlo_instruction,
GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
@@ -108,6 +110,12 @@ void HloExecutionProfiler::FinishHloInstruction(
std::unique_ptr<ScopedInstructionProfiler>
HloExecutionProfiler::MakeScopedInstructionProfiler(
const HloInstruction* hlo_instruction) {
+ if (do_profile_ && hlo_instruction != nullptr) {
+ // Make sure that we are not already measuring the time for the same
+ // 'hlo_instruction'.
+ CHECK(hlo_instructions_.insert(hlo_instruction).second)
+ << hlo_instruction->name();
+ }
return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
}
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
index e5c655edc6..6654850bef 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <stack>
+#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -74,6 +75,9 @@ class HloExecutionProfiler {
const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
const HloComputation* computation_;
std::stack<std::unique_ptr<se::Timer>> timers_;
+ // Contains the HLO instructions for which we are currently measuring the
+ // time.
+ std::unordered_set<const HloInstruction*> hlo_instructions_;
bool finished_execution_ = false;
};
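
[Annotation] The new unordered_set turns double-profiling into a hard failure: starting two scoped profilers for the same instruction would nest two timers and corrupt that instruction's cycle count. An illustrative (hypothetical) misuse that the CHECK now catches:

    {
      auto outer = profiler.MakeScopedInstructionProfiler(hlo);
      // CHECK-fails with the instruction's name before bad data is recorded:
      auto inner = profiler.MakeScopedInstructionProfiler(hlo);
    }

Instructions are erased again in FinishHloInstruction, so profiling the same HLO sequentially remains legal.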
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index 375709150e..19de37b0fb 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -100,7 +100,7 @@ GpuHloOrdering::GpuHloOrdering(
if (last_instruction_per_stream[stream_no] != nullptr) {
immediate_preds.push_back(last_instruction_per_stream[stream_no]);
}
- predecessor_map->SetReachabilityToUnion(immediate_preds, hlo);
+ predecessor_map->FastSetReachabilityToUnion(immediate_preds, hlo);
last_instruction_per_stream[stream_no] = hlo;
} else {
// Only parameters and constants don't have an assigned stream, since they
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index ae310beefa..c5f0cdf6cd 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -15,76 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
-#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace gpu {
-InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {}
-
-void InfeedManager::Reset() {
- tensorflow::mutex_lock l(mu_);
- CHECK(dequeued_buffer_.empty());
- for (auto buffer : enqueued_buffer_) {
- buffer->Done();
- }
- enqueued_buffer_.clear();
-}
-
-void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) {
- tensorflow::mutex_lock l(mu_);
- bool was_empty = enqueued_buffer_.empty();
- for (gpu::InfeedBuffer* b : buffers) {
- enqueued_buffer_.push_back(b);
- }
- if (was_empty) {
- // This has the potential to suffer from the notified thread
- // immediately trying and failing to acquire mu_, but seems
- // preferable to the alternative of notifying outside the lock
- // on every enqueue.
- cv_.notify_one();
- }
-}
-
-InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
- bool became_empty = false;
- InfeedBuffer* current_buffer;
- {
- tensorflow::mutex_lock l(mu_);
- while (enqueued_buffer_.empty()) {
- cv_.wait(l);
- }
- current_buffer = enqueued_buffer_.front();
- enqueued_buffer_.pop_front();
- dequeued_buffer_.insert(current_buffer);
- if (enqueued_buffer_.empty()) {
- became_empty = true;
- }
- }
- if (became_empty) {
- for (const auto& callback : on_empty_callbacks_) {
- callback();
- }
- }
- return current_buffer;
-}
-
-void InfeedManager::ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers) {
- {
- tensorflow::mutex_lock l(mu_);
- for (gpu::InfeedBuffer* b : buffers) {
- CHECK(ContainsKey(dequeued_buffer_, b));
- dequeued_buffer_.erase(b);
- }
- }
- for (gpu::InfeedBuffer* b : buffers) {
- b->Done();
- }
-}
-
se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
+ tensorflow::mutex_lock l(host_to_device_stream_mu_);
if (host_to_device_executor_ == nullptr) {
host_to_device_executor_ = executor;
host_to_device_stream_ = MakeUnique<se::Stream>(executor);
@@ -100,10 +37,6 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
return host_to_device_stream_.get();
}
-void InfeedManager::RegisterOnEmptyCallback(std::function<void()> callback) {
- on_empty_callbacks_.push_back(std::move(callback));
-}
-
InfeedManager* GetOrCreateInfeedManager() {
static InfeedManager* manager = new InfeedManager;
return manager;
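
The new mutex_lock at the top of GetStream serializes the lazy creation of the host-to-device stream, which was previously unsynchronized. A reduced sketch of the pattern, with placeholder types standing in for se::StreamExecutor and se::Stream:

    #include <memory>
    #include <mutex>

    struct Executor {};
    struct Stream { explicit Stream(Executor*) {} };

    class StreamCache {
     public:
      // Returns the cached stream, creating it on first use. Returns nullptr
      // if a different executor was cached earlier, matching GetStream.
      Stream* Get(Executor* executor) {
        std::lock_guard<std::mutex> lock(mu_);
        if (executor_ == nullptr) {
          executor_ = executor;
          stream_ = std::make_unique<Stream>(executor);
        }
        return executor_ == executor ? stream_.get() : nullptr;
      }

     private:
      std::mutex mu_;
      Executor* executor_ = nullptr;    // Not owned.
      std::unique_ptr<Stream> stream_;  // Guarded by mu_.
    };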
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
index a3fc15cfe3..7e418882e0 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
@@ -20,12 +20,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
-#include <deque>
-#include <vector>
-
+#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -47,90 +44,41 @@ namespace gpu {
// the client. The client manages the memory of the buffer.
class InfeedBuffer {
public:
+ InfeedBuffer() = default;
InfeedBuffer(se::StreamExecutor* executor, int64 length)
- : executor_(executor), length_(length) {
- device_memory_ = executor_->AllocateArray<uint8>(length);
- CHECK(!device_memory_.is_null());
+ : device_memory_(executor, executor->AllocateArray<uint8>(length)),
+ length_(length) {
+ CHECK(!device_memory_->is_null());
}
- ~InfeedBuffer() { executor_->Deallocate(&device_memory_); }
-
int64 length() const { return length_; }
- // Callback to signal that this buffer is consumed. This helps the
- // client to manage memory for the infeed buffers.
- void Done() { delete this; }
-
- se::DeviceMemoryBase* device_memory() { return &device_memory_; }
+ se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); }
private:
- se::StreamExecutor* executor_; // Not owned.
- const int64 length_;
- se::DeviceMemoryBase device_memory_;
+ se::ScopedDeviceMemory<uint8> device_memory_;
+ int64 length_;
};
// Client-side class used to enqueue infeed buffers.
-class InfeedManager {
+class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
public:
- InfeedManager();
-
- // Calls the completion callback for any enqueued buffers that have
- // not been dequeued by the runtime, and empties the infeed
- // queue. Reset may not be called while a runtime computation is
- // processing a dequeued buffer. The only safe way to ensure this
- // condition is to call Reset when no computation is taking place.
- void Reset();
-
- // Adds a set of buffers to the infeed queue atomically. buffer->Done
- // will be called when the buffer will no longer be accessed by the
- // InfeedManager, either as a result of a call to Reset or because the
- // runtime has dequeued and used the buffer.
- void EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers);
-
- // Blocks until the infeed queue is non-empty, then returns the
- // buffer at the head of the queue. Adds the current buffer to the
- // to-be released set.
- InfeedBuffer* BlockingDequeueBuffer();
-
- // Releases a set of buffers from the to-be released set.
- void ReleaseBuffers(const std::vector<InfeedBuffer*>& buffers);
-
// Returns a cached stream associated with an executor. Allocates a
// new stream on the first invocation. On subsequent invocations, if
// the cached executor is not the same as the requested executor,
// returns null.
se::Stream* GetStream(se::StreamExecutor* executor);
- // Registers a callback that will be called when 'enqueued_buffer_' becomes
- // empty.
- void RegisterOnEmptyCallback(std::function<void()> callback);
-
private:
- // TODO(b/30467474): Revisit if this mutex becomes a point of
- // contention.
- tensorflow::mutex mu_;
-
- // Condition variable that is signaled every time a buffer is
- // enqueued to an empty queue.
- tensorflow::condition_variable cv_;
-
- // InfeedBuffer* queue contents are not owned, but buffer->Done must
- // be called when the buffer is no longer needed by the runtime.
- std::deque<InfeedBuffer*> enqueued_buffer_;
-
- // Buffers that are dequeued and currently being processed by the
- // runtime. Not owned.
- tensorflow::gtl::FlatSet<const InfeedBuffer*> dequeued_buffer_;
+ // Mutex for serializing the creation of host_to_device_stream_.
+ tensorflow::mutex host_to_device_stream_mu_;
// Cached host to device stream for queuing infeed data.
- std::unique_ptr<se::Stream> host_to_device_stream_;
+ std::unique_ptr<se::Stream> host_to_device_stream_
+ GUARDED_BY(host_to_device_stream_mu_);
// Executor that the host_to_device_stream belongs to. Not owned.
- se::StreamExecutor* host_to_device_executor_;
-
- // List of callbacks which will be called when 'enqueued_buffer_' becomes
- // empty.
- std::vector<std::function<void()>> on_empty_callbacks_;
+ se::StreamExecutor* host_to_device_executor_ = nullptr;
};
// Singleton creator-or-accessor: Returns the GPU infeed manager.
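
InfeedManager now inherits its queueing behavior from XfeedQueue<ShapeTree<InfeedBuffer>>, replacing the hand-rolled deque, condition variable, dequeued-buffer set, and on-empty callbacks deleted above; ScopedDeviceMemory likewise replaces the manual Done()/Deallocate lifetime management. A condition-variable blocking queue along these lines (a sketch of the presumed XfeedQueue contract, not its actual implementation):

    #include <condition_variable>
    #include <deque>
    #include <mutex>

    template <typename T>
    class BlockingQueue {
     public:
      // Enqueues one destination; wakes a waiting consumer.
      void Enqueue(T value) {
        {
          std::lock_guard<std::mutex> lock(mu_);
          queue_.push_back(std::move(value));
        }
        cv_.notify_one();
      }

      // Blocks until a destination is available, then takes ownership of it.
      T BlockingGetNextDestination() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return !queue_.empty(); });
        T value = std::move(queue_.front());
        queue_.pop_front();
        return value;
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      std::deque<T> queue_;
    };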
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index 62915febb1..fee6d2af3b 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -30,51 +30,68 @@ InfeedThunk::InfeedThunk(
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
- VLOG(2) << "Infeeding to GPU ";
+ VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString();
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- // First copy the infeed data which is element 0 of the infeed instruction's
- // two-tuple output (the other element is a token).
- se::DeviceMemoryBase data_address =
- buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
- InfeedManager* infeed_manager = GetOrCreateInfeedManager();
- std::vector<InfeedBuffer*> infeed_buffers;
- const Shape& data_shape =
- ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
- if (ShapeUtil::IsTuple(data_shape)) {
- CHECK(!ShapeUtil::IsNestedTuple(data_shape));
- // Transfer the tuple elements first.
+ ShapeTree<InfeedBuffer> infeed_buffers =
+ GetOrCreateInfeedManager()->BlockingGetNextDestination();
+
+ {
+ // The infeed buffer has an extra outer tuple with a token. Adjust the index
+ // accordingly.
+ ShapeIndex index = {0};
+ std::function<void(std::vector<void*>*)> copy_tuple_contents =
+ [&](std::vector<void*>* tuple_element_addresses) {
+ const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(),
+ ShapeIndexView(index, 1));
+ // For the leaf buffers of the tuple copy the elements directly.
+ if (ShapeUtil::IsArray(shape)) {
+ const BufferAllocation::Slice& tuple_element_buffer =
+ infeed_slices_.element(index);
+ se::DeviceMemoryBase tuple_element_address =
+ buffer_allocations.GetDeviceAddress(tuple_element_buffer);
+
+ InfeedBuffer* buffer =
+ infeed_buffers.mutable_element(ShapeIndexView(index, 1));
+ stream->ThenMemcpy(&tuple_element_address,
+ *(buffer->device_memory()), buffer->length());
+ tuple_element_addresses->push_back(tuple_element_address.opaque());
+ return;
+ }
+
+ const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape);
+ index.push_back(0);
+ std::vector<void*> inner_tuple_element_addresses;
+ for (int64 i = 0; i < tuple_element_count; ++i) {
+ index.back() = i;
+ copy_tuple_contents(&inner_tuple_element_addresses);
+ }
+ index.pop_back();
+
+ // Create a buffer of pointers for non-leaf buffers.
+ CHECK_EQ(tuple_element_count, inner_tuple_element_addresses.size());
+ auto host_size = inner_tuple_element_addresses.size() * sizeof(void*);
+ se::DeviceMemoryBase tuple_address =
+ buffer_allocations.GetDeviceAddress(
+ infeed_slices_.element(index));
+ stream->ThenMemcpy(&tuple_address,
+ inner_tuple_element_addresses.data(), host_size);
+ tuple_element_addresses->push_back(tuple_address.opaque());
+ };
+
std::vector<void*> tuple_element_addresses;
- for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) {
- const BufferAllocation::Slice& tuple_element_buffer =
- infeed_slices_.element({0, i});
- se::DeviceMemoryBase tuple_element_address =
- buffer_allocations.GetDeviceAddress(tuple_element_buffer);
-
- InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
- infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()),
- buffer->length());
- tuple_element_addresses.push_back(tuple_element_address.opaque());
- }
- // Transfer the tuple outer buffer.
- auto host_size = tuple_element_addresses.size() * sizeof(void*);
- stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
- host_size);
- } else {
- InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
- infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
- buffer->length());
+ copy_tuple_contents(&tuple_element_addresses);
+ CHECK_EQ(1, tuple_element_addresses.size());
}
// Construct top-level tuple of infeed containing the data and the token. Use
  // a nullptr for the token; it should never be dereferenced.
- std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
+ se::DeviceMemoryBase data_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
+ void* infeed_addresses[] = {data_address.opaque(), nullptr};
se::DeviceMemoryBase top_level_address =
buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
- stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
- 2 * sizeof(void*));
+ stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*));
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
@@ -82,8 +99,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
stream, block_status.error_message().c_str());
}
- infeed_manager->ReleaseBuffers(infeed_buffers);
-
VLOG(2) << "Infeeding to GPU complete";
return Status::OK();
}
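
The copy_tuple_contents lambda above replaces the old flat, one-level tuple handling with a recursion that also covers nested tuples: leaf buffers are copied directly, and every interior node gets a freshly written array of its children's device addresses. A host-side sketch of the same bottom-up traversal over a toy tuple tree (Node and Materialize are illustrative, not XLA types):

    #include <cstring>
    #include <vector>

    struct Node {
      void* buffer = nullptr;      // Backing buffer: leaf data, or space for
                                   // an array of child pointers.
      std::vector<Node> children;  // Empty for leaf (array) nodes.
    };

    // Returns the address backing `node`, materializing pointer arrays for
    // interior nodes bottom-up, as copy_tuple_contents does on-stream.
    void* Materialize(Node* node) {
      if (node->children.empty()) {
        return node->buffer;  // Leaf: data was already copied here.
      }
      std::vector<void*> child_addresses;
      for (Node& child : node->children) {
        child_addresses.push_back(Materialize(&child));
      }
      // The real thunk issues stream->ThenMemcpy into the tuple's buffer
      // slice; the sketch writes the pointer array with a plain memcpy.
      std::memcpy(node->buffer, child_addresses.data(),
                  child_addresses.size() * sizeof(void*));
      return node->buffer;
    }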
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 1963d9eef7..98ba162cd9 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndOperandElementReusingConsumerNotFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* broadcast2 =
@@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest,
NonCostlyProducerAndOperandElementReusingConsumerFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
HloInstruction* broadcast2 =
@@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* reshape2 = builder.AddInstruction(
@@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* transpose2 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 80208e1c98..673ba530df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@@ -79,6 +80,7 @@ namespace gpu {
namespace {
+using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::InlinedVector;
@@ -355,7 +357,8 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
unroll_factor = ComputeMaxUnrollFactor(hlo);
}
- thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor));
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ hlo, /*implements_whole_instruction=*/true, unroll_factor));
return IrEmitter::DefaultAction(hlo);
}
@@ -369,7 +372,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
return IrEmitter::HandleDot(dot);
}
@@ -379,7 +383,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
- thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
return IrEmitter::HandleConvolution(convolution);
}
@@ -586,10 +591,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
}
CHECK(first_reduce != nullptr);
- thunks.push_back(BuildKernelThunk(fusion));
+ thunks.push_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
- std::vector<llvm_ir::IrArray> parameter_arrays;
+ std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -660,8 +666,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// touching the un-updated elements.
// Set up kernel thunk and fused ir emitter.
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
- std::vector<llvm_ir::IrArray> operand_arrays;
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/true));
+ std::vector<IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -674,7 +681,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// Array to write into. Because this is an in-place operation, this is the
// same as operand 0's array.
- llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
+ IrArray output_array = GetIrArray(*fusion, *fusion);
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
update_shape, ir_emitter_context_->device_description());
@@ -687,314 +694,25 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
fusion, operand_arrays, output_array, &elemental_emitter,
launch_dimensions, &ir_builder_);
}
+
if (ImplementedAsGemm(*fusion)) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
- CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
- int unroll_factor = ComputeMaxUnrollFactor(fusion);
-
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
- return IrEmitter::HandleFusion(fusion);
-}
-
-namespace {
-
-// Returns the indices of the first elements of all consecutive subarrays of the
-// given array. For example:
-// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
- std::vector<size_t> is = {0};
- for (size_t i = 1; i < xs.size(); ++i) {
- if (1 != xs[i] - xs[i - 1]) {
- is.push_back(i);
- }
- }
- return is;
-}
-
-// Merges the sequences of dimensions of the given shape which start at the
-// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
- std::vector<int64> dimensions;
- for (size_t i = 1; i <= segs.size(); ++i) {
- dimensions.push_back(std::accumulate(
- shape.dimensions().begin() + segs[i - 1],
- shape.dimensions().begin() +
- (segs.size() == i ? shape.dimensions().size() : segs[i]),
- 1, std::multiplies<int64>()));
- }
- return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
- dimensions);
-}
+ CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
-// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
-// if so, the normalized and rank-reduced shapes. The shapes must have the same
-// dimensions, so this considers layout only.
-//
-// This function recognizes higher-rank transposes which are elementwise
-// equivalent to a 0-2-1 transpose.
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
- CHECK(ShapeUtil::Compatible(a, b));
- std::vector<int64> perm(a.dimensions().size());
- {
- auto layout_a_orig = LayoutUtil::MinorToMajor(a);
- std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
- auto layout_b_orig = LayoutUtil::MinorToMajor(b);
- std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
- for (size_t i = 0; i < perm.size(); ++i) {
- perm[i] = PositionInContainer(layout_b, layout_a[i]);
- }
+ if (CheckAndEmitHloWithTile021(fusion)) {
+ return Status::OK();
}
- auto segs = ConsecutiveSegments(perm);
- Shape norm_a =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
- Shape norm_b =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
- if (3 == segs.size() && 0 == perm[0]) {
- Shape reduced_a = MergeDimensions(segs, norm_a);
- Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
- b.element_type(),
- Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
- return std::make_tuple(true, reduced_a, reduced_b);
- }
- return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
-}
-// Returns whether the given shapes could be those of a 0-2-1 transpose.
-// As 0-2-1 is a self-inverse permutation, which shape is input or output is
-// arbitrary.
-bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
- return 3 == b.dimensions().size() &&
- ShapeUtil::Compatible(
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
- ShapeUtil::PermuteDimensions(
- {0, 2, 1},
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- b)));
-}
-
-// Emits a tiled 0-2-1 transpose, assuming both input and output are laid out from
-// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
-// transposes one tile: each thread copies a row from the input to a shared
-// memory tile, then copies a column from the shared memory tile to the output.
-//
-// `tile_size` should usually be same as warp size.
-//
-// Returns (number of tiles = number of thread blocks needed).
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
-// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
-int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
- const int64 tile_size, const int64 num_rows,
- llvm::IRBuilder<>* builder) {
- // Adds `addend` to the given `dim` of `index`.
- auto offset_dim = [builder](llvm_ir::IrArray::Index index,
- llvm::Value* addend, int64 dim) {
- index[dim] = builder->CreateAdd(index[dim], addend);
- return index;
- };
-
- CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
-
- Shape input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input.GetShape());
- Shape output_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- output.GetShape());
- input = input.CastToShape(input_shape, builder);
- output = output.CastToShape(output_shape, builder);
-
- llvm::Type* tile_type = llvm::ArrayType::get(
- llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
-      // One extra here to avoid shared memory bank conflicts
- tile_size + 1);
- auto* tile = new llvm::GlobalVariable(
- *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
- /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
- llvm::UndefValue::get(tile_type), "tile", nullptr,
- llvm::GlobalValue::NotThreadLocal,
- /*AddressSpace=*/3 /* GPU shared memory */);
-
- // let x = threadIdx.x
- llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
- static_cast<llvm::Instruction*>(x));
- x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
- "thread.id.x");
-
- // computing logical thread ids
- // logical_x = x % tile_size
- auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
-
- // logical_y = x / tile_size
- auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
-
-  // `emit_cp_tile` emits code equivalent to the following pseudocode:
- // if (tile_size == tile_width && tile_size == tile_height) {
- // unroll for (i in range(0, tile_size, num_rows)) {
- // emit_cp_element(index + {0, i, 0}, y + logical_y);
- // }
- // } else if (x < tile_width) {
- // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
- // for (i in range(0, tile_height_upperbound, num_rows)) {
- // y_loc = i + logical_y;
- // if (y_loc < tile_height)
- // emit_cp_element(index + {0, i, 0}, y_loc);
- // }
- // }
- //
- // We use this to emit both the copy from input to tile and the copy from tile
- // to output.
- //
- // `index` is the origin of the row or column in the input or output array.
- //
-  // `emit_cp_element(index, y)` emits code to copy a single element between the
-  // tile and the input or output array, where `y` is the `y`-position in the
-  // tile (whether that is a row or a column depends on whether we're copying
-  // from input or to output), and `index` is the index into the input or
-  // output array.
- auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
- logical_y](
- std::function<void(const llvm_ir::IrArray::Index&,
- llvm::Value*)>
- emit_cp_element,
- llvm::Value* tile_width, llvm::Value* tile_height,
- const llvm_ir::IrArray::Index& index,
- const string& loop_name) {
- llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
- builder->CreateAnd(
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
- "not_last_row", builder);
- builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
- for (int64 i = 0; i < tile_size; i += num_rows) {
- auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
- auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
- emit_cp_element(source_idx, y_loc);
- }
- builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
- llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
- builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
-
- // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
- auto tile_height_upper_bound = builder->CreateMul(
- builder->CreateUDiv(
- builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
- builder->getInt64(num_rows)),
- builder->getInt64(num_rows));
-
- auto loop = llvm_ir::ForLoop::EmitForLoop(
- loop_name, builder->getInt64(0), tile_height_upper_bound,
- builder->getInt64(num_rows), builder);
- llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
- builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
-
- auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
- auto if_y_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
- builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
-
- emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
- y_loc);
- builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
- };
-
- auto input_dims_in_tiles = input_shape.dimensions();
- // Unpermuted dimensions are untiled.
- for (int i = 1; i < 3; ++i) {
- input_dims_in_tiles[i] =
- CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
- }
- int64 num_tiles =
- std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
- std::multiplies<int64>());
- const llvm_ir::IrArray::Index input_tile_index(
- /*linear=*/builder->CreateIntCast(
- llvm_ir::AddRangeMetadata(
- 0, num_tiles,
- static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
- builder))),
- builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
- ShapeUtil::MakeShapeWithDescendingLayout(
- PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
- builder);
- const llvm_ir::IrArray::Index input_tile_origin = ({
- llvm_ir::IrArray::Index index = input_tile_index;
- for (int i = 1; i < 3; ++i) {
- index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
- "tile_origin." + std::to_string(i));
- }
- index;
- });
- const llvm_ir::IrArray::Index input_index =
- offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
- /*dim=*/1);
- std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
- // Only last row or column may not have full size.
- for (int i = 1; i < 3; ++i) {
- tile_dims[i] = builder->CreateSelect(
- builder->CreateICmpEQ(input_tile_index[i],
- builder->getInt64(input_dims_in_tiles[i] - 1)),
- builder->getInt64(input_shape.dimensions(i) -
- (input_dims_in_tiles[i] - 1) * tile_size),
- builder->getInt64(tile_size), "tile_size");
- }
-
- // Load data from input memory to shared memory tile.
- emit_cp_tile(
- // tile[y, x] = input_array[index]
- [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- builder->CreateStore(
- input.EmitReadArrayElement(index, builder, "input_element"),
- builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
- },
- tile_dims[2], tile_dims[1], input_index, "input");
+ int unroll_factor = ComputeMaxUnrollFactor(fusion);
- // Wait for all threads to reach this point, lest we copy a value from tile to
- // output before the other thread copies it from input to tile.
- // This is `__syncthreads` in CUDA.
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
-
- const llvm_ir::IrArray::Index output_tile_index(
- Permute({0, 2, 1}, input_tile_index.multidim()));
- const llvm_ir::IrArray::Index output_tile_origin(
- Permute({0, 2, 1}, input_tile_origin.multidim()));
- const llvm_ir::IrArray::Index output_index =
- offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
- logical_y, /*dim=*/1);
-
- // Store data from shared memory tile to output memory.
- emit_cp_tile(
- // output_array[index] = tile[x, y]
- [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- output.EmitWriteArrayElement(
- index,
- builder->CreateLoad(
- builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
- "output_element"),
- builder);
- },
- tile_dims[1], tile_dims[2], output_index, "output");
-
- return num_tiles;
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ fusion, /*implements_whole_instruction=*/true, unroll_factor));
+ return IrEmitter::HandleFusion(fusion);
}
-} // namespace
-
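
The tiled 0-2-1 transpose emitter deleted here, together with its shape helpers, moves behind the shared CheckAndEmitHloWithTile021 entry point now used by both HandleFusion and HandleCopy. For reference, what it computes is the classic shared-memory tile transpose: each block stages a tile_size x (tile_size + 1) tile (the +1 avoids shared-memory bank conflicts), synchronizes, then writes the tile back transposed. A plain C++ sketch of the operation's semantics on a dense row-major [depth, height, width] array (the real code emits LLVM IR and runs one thread block per tile):

    #include <vector>

    // out[z][x][y] = in[z][y][x]: transposes the two minor-most dimensions.
    std::vector<float> Transpose021(const std::vector<float>& in, int depth,
                                    int height, int width) {
      std::vector<float> out(in.size());
      for (int z = 0; z < depth; ++z)
        for (int y = 0; y < height; ++y)
          for (int x = 0; x < width; ++x)
            out[(z * width + x) * height + y] =
                in[(z * height + y) * width + x];
      return out;
    }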
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
*copy)) {
@@ -1006,25 +724,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}
- bool is_transpose_021;
- Shape reduced_input_shape, reduced_output_shape;
- std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
- IsTranspose021(copy->operand(0)->shape(), copy->shape());
- if (is_transpose_021 &&
- reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
- reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
- thunk_sequence_->emplace_back(BuildKernelThunk(copy));
- VLOG(3) << "Emitting tiled 0-2-1 transposition";
- constexpr int64 tile_size = 32;
- constexpr int64 num_rows = 8;
- int64 num_tiles = EmitTranspose021Tiled(
- GetIrArray(*copy->operand(0), *copy)
- .CastToShape(reduced_input_shape, &ir_builder_),
- GetIrArray(*copy, *copy)
- .CastToShape(reduced_output_shape, &ir_builder_),
- tile_size, num_rows, &ir_builder_);
- UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
- LastThunk(), ir_emitter_context_->llvm_module());
+ if (CheckAndEmitHloWithTile021(copy)) {
return Status::OK();
}
@@ -1032,7 +732,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
- const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ const HloInstruction* reduce, const IrArray::Index& index,
tensorflow::gtl::ArraySlice<
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
@@ -1075,11 +775,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
tiled_input_shape, ir_emitter_context_->device_description());
llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
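
Both reduction emitters now ask LaunchDimensions for launch_bound() rather than spelling out the product; presumably the helper is simply block_count() * threads_per_block(), i.e. an exclusive upper bound on the linear thread id, as in this sketch:

    #include <cstdint>

    struct LaunchDimensions {
      int64_t block_count;
      int64_t threads_per_block;
      // Every linear thread id lies in [0, launch_bound()).
      int64_t launch_bound() const { return block_count * threads_per_block; }
    };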
@@ -1121,8 +819,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // and threads_per_block is a multiple of warpSize.
// reduce_kernel<<<num_blocks, threads_per_block>>>();
//
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
@@ -1131,9 +828,8 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1145,21 +841,22 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)),
+ ir_builder_.CreateICmpULT(x, index_typed_constant(num_elems)),
"x_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
@@ -1167,7 +864,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
}
- llvm_ir::IrArray::Index input_index(
+ IrArray::Index input_index(
/*linear=*/x, input_shape, &ir_builder_);
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
@@ -1185,12 +882,12 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
llvm::Value* x_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
    // The tile is entirely in bounds if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)),
+ ir_builder_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
ir_builder_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
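
Aside from the index_typed_constant rename, the logic in this hunk is the per-tile bounds test: a tile starting at x_in_tiles * kTileSize is entirely in bounds iff its exclusive end does not pass num_elems, which lets the element loop drop its per-element check. The same predicate in plain C++ (the kTileSize value is illustrative):

    #include <cstdint>

    constexpr int64_t kTileSize = 16;

    // True if the whole tile [start, start + kTileSize) lies inside the
    // input, so the inner loop can skip per-element bounds checks.
    bool TileInBounds(int64_t x_in_tiles, int64_t num_elems,
                      bool all_threads_in_bounds) {
      const int64_t x_end = kTileSize + x_in_tiles * kTileSize;
      return all_threads_in_bounds || x_end <= num_elems;
    }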
@@ -1241,9 +938,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id = ir_builder_.CreateURem(
- x_in_tiles, index_typed_const(kWarpSize), "lane_id");
+ x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
@@ -1252,7 +949,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
+ IrArray::Index(
/*linear=*/ir_builder_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
@@ -1311,7 +1008,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
llvm::Type* index_ty = ir_builder_.getInt64Ty();
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
@@ -1338,8 +1035,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
// }
// AtomicReducer(&output[x], partial_result);
// }
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type =
@@ -1349,9 +1045,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1367,22 +1062,23 @@ Status IrEmitterUnnested::EmitColumnReduction(
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_const(0),
- index_typed_const(kTileSize), index_typed_const(1), &ir_builder_);
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* y = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// y-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(y, index_typed_const(height)),
+ ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
"y_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
@@ -1406,9 +1102,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
const Shape input_matrix_shape =
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
- const llvm_ir::IrArray::Index input_matrix_index(
- {y, x}, input_matrix_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
+ &ir_builder_);
+ const IrArray::Index input_index =
input_matrix_index
.SourceIndexOfReshape(input_matrix_shape,
normalized_input_shape, &ir_builder_)
@@ -1432,10 +1128,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
// immediately beyond the tile.
llvm::Value* y_end = ir_builder_.CreateNSWAdd(
- index_typed_const(kTileSize),
- ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(y_end, index_typed_const(height)),
+ ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
ir_builder_.getInt1(height % kTileSize == 0));
// The tile is entirely in bound if "height" is a multiple of kTileSize or
// y_end <= height.
@@ -1459,11 +1155,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- x,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
+ IrArray::Index(x,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
@@ -1629,15 +1324,13 @@ Status IrEmitterUnnested::EmitRowReduction(
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
llvm::Type* index_ty = GetIndexTypeForKernel(
- reduce,
- launch_dimensions.block_count() * launch_dimensions.threads_per_block(),
- &ir_builder_);
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_ty, c);
};
- auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
@@ -1646,9 +1339,8 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index(index_ty)));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1660,20 +1352,20 @@ Status IrEmitterUnnested::EmitRowReduction(
x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty);
- llvm::Value* warp_id =
- ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id");
- llvm::Value* lane_id =
- ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id");
+ llvm::Value* warp_id = ir_builder_.CreateUDiv(
+ x_tile, index_typed_constant(kWarpSize), "warp_id");
+ llvm::Value* lane_id = ir_builder_.CreateURem(
+ x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
llvm::Value* last_x = ir_builder_.CreateNSWAdd(
lane_id, ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
+ index_typed_constant(kWarpSize),
ir_builder_.CreateNSWAdd(
- index_typed_const(x_tile_size - 1),
+ index_typed_constant(x_tile_size - 1),
ir_builder_.CreateNSWMul(
- warp_id, index_typed_const(x_tile_size)))));
+ warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&ir_builder_,
@@ -1686,19 +1378,19 @@ Status IrEmitterUnnested::EmitRowReduction(
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
llvm::Value* z = ir_builder_.CreateNSWAdd(
- z_indvar,
- ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile));
+ z_indvar, ir_builder_.CreateNSWMul(
+ index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(x_tile_loop_bound),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(x_tile_loop_bound),
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
llvm::Value* x = ir_builder_.CreateNSWAdd(
lane_id,
ir_builder_.CreateNSWMul(
- index_typed_const(kWarpSize),
+ index_typed_constant(kWarpSize),
ir_builder_.CreateNSWAdd(
x_indvar, ir_builder_.CreateNSWMul(
warp_id, llvm::ConstantInt::get(
@@ -1708,9 +1400,9 @@ Status IrEmitterUnnested::EmitRowReduction(
// emit a x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
- llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, index_typed_const(width)),
- "x_in_bounds", &ir_builder_);
+ llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT(
+ x, index_typed_constant(width)),
+ "x_in_bounds", &ir_builder_);
// Points ir_builder_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&ir_builder_);
@@ -1737,9 +1429,9 @@ Status IrEmitterUnnested::EmitRowReduction(
const Shape input_3d_tensor_shape =
ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {depth, height, width});
- const llvm_ir::IrArray::Index input_3d_tensor_index(
+ const IrArray::Index input_3d_tensor_index(
{z, y, x}, input_3d_tensor_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_index =
input_3d_tensor_index
.SourceIndexOfReshape(input_3d_tensor_shape,
normalized_input_shape,
@@ -1765,14 +1457,14 @@ Status IrEmitterUnnested::EmitRowReduction(
};
return ksl.For("z_tile",
- /*start=*/index_typed_const(0),
- /*end=*/index_typed_const(z_tile_size),
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(z_tile_size),
/*step=*/1, emit_z_tile_element_loop);
};
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- ir_builder_.CreateICmpULT(last_x, index_typed_const(width)));
+ ir_builder_.CreateICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1826,7 +1518,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
@@ -1834,11 +1526,10 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- y,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &ir_builder_),
+ IrArray::Index(y,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
// We don't need to emit atomic operations if there is only one tile of
// results. 'depth' is the z dimension, 'width' is the x dimension.
@@ -1982,23 +1673,25 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
BuildInitializerThunk(reduce));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(reduce));
+ thunks.push_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
- reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
+ reduce, input->shape(), {[&](const IrArray::Index& index) {
return GetIrArray(*input, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
- {[&](const llvm_ir::IrArray::Index& index) {
+ {[&](const IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
dimensions_to_reduce, {reducer}, {{}}, {});
}
- thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/true));
return IrEmitter::HandleReduce(reduce);
}
@@ -2027,7 +1720,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTuple(tuple);
}
@@ -2052,7 +1746,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
BuildInitializerThunk(select_and_scatter));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(select_and_scatter));
+ thunks.push_back(BuildKernelThunk(select_and_scatter,
+ /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
@@ -2066,7 +1761,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
source->shape(), ir_emitter_context_->device_description());
llvm::Type* index_type = GetIndexTypeForKernel(
select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_);
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_type, c);
};
@@ -2089,8 +1784,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// selected_index = I
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& source_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
// Allocate space to keep the currently selected value, its index, and a
// boolean flag if the value is initialized. The initialized_flag is set
// false.
@@ -2100,7 +1794,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
"selected_value_address", &ir_builder_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- index_type, index_typed_const(rank), "selected_index_address",
+ index_type, index_typed_constant(rank), "selected_index_address",
&ir_builder_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
@@ -2115,7 +1809,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
window_size.push_back(dim.size());
CHECK_GT(dim.size(), 0);
}
- const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ const IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
&ir_builder_);
@@ -2123,17 +1817,17 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(index_type, source_index.size());
+ IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- source_index[i], index_typed_const(window.dimensions(i).stride()));
+ source_index[i], index_typed_constant(window.dimensions(i).stride()));
operand_index[i] = ir_builder_.CreateNSWSub(
ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ index_typed_constant(window.dimensions(i).padding_low()));
llvm::Value* index_condition = ir_builder_.CreateICmpULT(
operand_index[i],
- index_typed_const(ShapeUtil::GetDimension(operand->shape(), i)));
+ index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
in_bounds_condition =
ir_builder_.CreateAnd(in_bounds_condition, index_condition);
}
@@ -2151,8 +1845,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
- const auto save_operand_index = [&](
- const llvm_ir::IrArray::Index& operand_index) {
+ const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
ir_builder_.CreateInBoundsGEP(selected_index_address,
@@ -2160,7 +1853,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
- llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
+ IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
ir_builder_.CreateStore(operand_data, selected_value_address);
@@ -2205,7 +1898,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// value and the current output value.
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
- llvm_ir::IrArray::Index selected_index(operand_index.GetType());
+ IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
@@ -2260,17 +1953,20 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
}
Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
- thunk_sequence_->push_back(BuildKernelThunk(random));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(random, /*implements_whole_instruction=*/true));
return IrEmitter::HandleRng(random);
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
- thunk_sequence_->push_back(BuildKernelThunk(select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleSelect(select);
}
Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
- thunk_sequence_->push_back(BuildKernelThunk(tuple_select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTupleSelect(tuple_select);
}
@@ -2309,12 +2005,12 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
- /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), crs));
+ GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
MakeUnique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
@@ -2329,6 +2025,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
return Status::OK();
}
+Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
+ thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed));
+ return Status::OK();
+}
+
// Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO).
//
@@ -2448,7 +2149,8 @@ GetHloBufferSlices(const HloInstruction* hlo,
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst, int unroll_factor) {
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2540,7 +2242,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst, unroll_factor);
+ implements_whole_instruction ? inst : nullptr,
+ unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2574,7 +2277,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
ShapeTree<BufferAllocation::Slice> slices(inst->shape());
slices.ForEachMutableElement(
- [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
@@ -2582,6 +2285,23 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
return MakeUnique<InfeedThunk>(slices, inst);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
+ const HloInstruction* inst) {
+ CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
+
+ ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
+ slices.ForEachMutableElement(
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ auto status_or_slice =
+ ir_emitter_context_->buffer_assignment().GetUniqueSlice(
+ inst->operand(0), index);
+ if (status_or_slice.ok()) {
+ *slice = status_or_slice.ConsumeValueOrDie();
+ }
+ });
+ return MakeUnique<OutfeedThunk>(std::move(slices), inst);
+}
+
namespace {
double GetScalarConstantAsDouble(const Literal& literal) {
switch (literal.shape().element_type()) {
@@ -2697,6 +2417,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
init_value = hlo->operand(init_value->parameter_number());
}
+ // Initializer thunks don't implement a whole instruction, and we want to
+ // profile the whole instruction instead of the individual thunks it consists
+ // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
+ // generate below.
+ //
// In the common case, the initializer is a constant. In this case, emit a
// device-memset call if we can. Currently StreamExecutor only supports
// zeroing and 32-bit memsets.
@@ -2710,7 +2435,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
+ return {
+ MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2728,7 +2454,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {MakeUnique<Memset32BitValueThunk>(
- pattern32, GetAllocationSlice(*hlo, index), hlo)};
+ pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2739,12 +2465,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {MakeUnique<Memset32BitValueThunk>(
- word, GetAllocationSlice(*hlo, index), hlo)};
+ word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
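
The widening trick above generalizes: an 8-bit pattern is replicated into 16 bits and then into 32 bits, so a byte-valued initializer can still take the 32-bit memset path. A minimal standalone sketch of that arithmetic (illustrative values only, not part of the patch):

#include <cstdint>
#include <cstdio>

int main() {
  // Replicate an 8-bit pattern to 16 bits and then to 32 bits, mirroring the
  // pattern16/pattern32 computation in BuildInitializerThunk above.
  uint8_t byte = 0xAB;
  uint16_t pattern16 = uint16_t{byte} | (uint16_t{byte} << 8);
  uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16);
  std::printf("0x%08x\n", pattern32);  // Prints 0xabababab.
  return 0;
}
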
// Otherwise fall back to our slow initializer code.
- std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo);
+ std::unique_ptr<KernelThunk> kernel_thunk =
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
LaunchDimensions launch_dimensions =
CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
ir_emitter_context_->device_description());
@@ -2756,7 +2483,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const llvm_ir::IrArray::Index& index) {
+ [=](const IrArray::Index& index) {
return GetIrArray(*init_value, *hlo)
.EmitReadArrayElement(index, &ir_builder_);
},
@@ -2951,8 +2678,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
&ir_builder_));
}
- // For multiple outputs fusion, we need to emit each operand and the root.
- std::vector<llvm_ir::IrArray> output_arrays;
+ // For multioutput fusion, we need to emit each operand and the root.
+ std::vector<IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
@@ -2981,5 +2708,482 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
+int IrEmitterUnnested::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays->push_back(GetIrArray(hlo, hlo));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_arrays->reserve(num_params);
+ for (const HloInstruction* param : hlo.operands()) {
+ param_arrays->push_back(GetIrArray(*param, hlo));
+ }
+ return num_params;
+}
+
+int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<IrArray>* output_in_reduced_shape_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_in_reduced_shape_arrays->reserve(num_outputs);
+ output_reduced_shapes->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
+ reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[i].CastToShape(
+ (*output_reduced_shapes)[i], &ir_builder_));
+ }
+ } else {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ hlo.shape().element_type(), reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[0].CastToShape(
+ (*output_reduced_shapes)[0], &ir_builder_));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<IrArray>* param_in_reduced_shape_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_in_reduced_shape_arrays->reserve(num_params);
+ param_reduced_shapes->reserve(num_params);
+ for (int64 id = 0; id < num_params; ++id) {
+ if (param_buffers[id] == nullptr) {
+ param_reduced_shapes->push_back(Shape());
+ param_in_reduced_shape_arrays->push_back(IrArray());
+ continue;
+ }
+ const HloInstruction* param = hlo.operand(id);
+ param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ param->shape().element_type(),
+ Permute({0, 2, 1}, reduced_output_dims)));
+ param_in_reduced_shape_arrays->push_back(param_arrays[id].CastToShape(
+ (*param_reduced_shapes)[id], &ir_builder_));
+ }
+ return num_params;
+}
+
+namespace {
+
+// Reads thread_idx.x and converts it to a (y,x) coordinate within a tile that
+// is tile_size elements wide, so a thread block of threads_per_tile threads
+// covers threads_per_tile / tile_size rows of the tile.
+std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
+ llvm::IRBuilder<>* builder, llvm::Value* tile_size,
+ int64 threads_per_tile) {
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, threads_per_tile,
+ llvm::cast<llvm::Instruction>(thread_id));
+ thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
+ /*isSigned=*/true, "thread.id.x");
+ auto x = builder->CreateURem(thread_id, tile_size);
+ auto y = builder->CreateUDiv(thread_id, tile_size);
+ return std::make_tuple(y, x);
+}
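
Concretely, with the tile_size = 32 and threads_per_tile = 128 configuration used by EmitHlo021Tile below, the URem/UDiv pair lays the threads out as 4 rows of 32. A standalone sketch of the mapping (not emitter output):

#include <cstdio>

int main() {
  const int tile_size = 32;          // kTileSize below.
  const int threads_per_tile = 128;  // kTileSize * kNumRows below.
  for (int thread_id = 0; thread_id < threads_per_tile; thread_id += 32) {
    int y = thread_id / tile_size;  // CreateUDiv above.
    int x = thread_id % tile_size;  // CreateURem above.
    std::printf("thread %3d -> (y=%d, x=%d)\n", thread_id, y, x);
  }
  return 0;
}
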
+
+// Reads block_idx.x, casts it to type index_ty, and adds the assumption that
+// it's in the range [0, num_blocks).
+llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+ int64 num_blocks) {
+ llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, num_blocks,
+ llvm::cast<llvm::Instruction>(block_id));
+ return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
+ "block.id.x");
+}
+
+// Emits code to process up to (tile_size/num_rows) elements in a tile.
+// `emit_elem_function` emits the code that processes one element, `y` and `x`
+// are the coordinates of the first element to process, and `index` is the
+// index of the tile origin. A bounds check is emitted to ensure that each
+// processed element lies within the boundary defined by `tile_width` and
+// `tile_height`.
+void EmitTiledElementalCodeWithBoundsCheck(
+ int64 tile_size, int64 num_rows, const IrArray::Index& index,
+ const string& loop_name, KernelSupportLibrary* ksl,
+ llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ llvm::Type* index_ty = tile_width->getType();
+ // Emits a constant value with index type.
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = builder->CreateAdd(index[dim], addend);
+ return index;
+ };
+
+ auto emit_full_tile = [&] {
+ for (int64 i = 0; i < tile_size; i += num_rows) {
+ auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(index_typed_constant(i), y);
+ emit_elem_function(source_idx, y_loc);
+ }
+ };
+
+ auto emit_last_row = [&] {
+ ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] {
+ // tile_height_upper_bound =
+ // ceil(tile_height / num_rows) * num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height,
+ index_typed_constant(num_rows - 1)),
+ index_typed_constant(num_rows)),
+ index_typed_constant(num_rows));
+ ksl->ForReturnVoid(
+ loop_name, /*start=*/index_typed_constant(0),
+ /*end=*/tile_height_upper_bound,
+ /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) {
+ auto y_loc = builder->CreateAdd(y_indvar, y);
+ ksl->IfReturnVoid(
+ "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] {
+ emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1),
+ y_loc);
+ });
+ });
+ });
+ };
+ ksl->IfReturnVoid(
+ "full_tile",
+ builder->CreateAnd(
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width),
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)),
+ emit_full_tile, emit_last_row);
+}
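
The tile_height_upper_bound expression above is the standard integer ceil-division identity, rounded back up to a multiple of num_rows so the row loop always advances in whole num_rows steps. A tiny standalone check (not emitter output):

#include <cassert>

// ceil(a / b) * b in pure integer arithmetic, as computed for
// tile_height_upper_bound above.
int RoundUpToMultiple(int a, int b) { return (a + b - 1) / b * b; }

int main() {
  assert(RoundUpToMultiple(30, 4) == 32);  // Partial tile: height 30, 4 rows.
  assert(RoundUpToMultiple(32, 4) == 32);  // Full tile needs no rounding.
  return 0;
}
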
+} // namespace
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// which have a shape that is a 0-2-1 transpose of the output tensors.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape of
+// three components 0-1-2 in the order major to minor. The x- and y- dimensions
+// of the tensors are tiled in square tiles of edge length `kTileSize`. Each
+// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each
+// thread copies kTileSize/kNumRows elements from the input to a shared memory
+// tile, then the otherwise "regular hlo kernel" reads from the shared memory
+// instead of the original input.
+//
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
+//
+// `kTileSize` should usually be the same as the warp size. We currently choose
+// 32 for `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for
+// `kNumRows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
+// to launch fewer blocks so each transposes many tiles.
+LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
+ HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ // Parameters for the tiling algorithm.
+ constexpr int64 kTileSize = 32;
+ constexpr int64 kNumRows = 4;
+ constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
+
+ // Construct IrArrays for the inputs and outputs.
+ std::vector<IrArray> output_arrays;
+ int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
+ std::vector<IrArray> param_arrays;
+ int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+
+ // Allocate shared memory buffers to store the tiled inputs.
+ std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
+ for (int64 id : tiled_param_ids) {
+ const HloInstruction* param = hlo->operand(id);
+ // Add 1 to the minor dimension to reduce shared memory bank conflicts.
+ llvm::Type* tile_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
+ param->shape().element_type(), module_),
+ kTileSize + 1),
+ kTileSize);
+ const int kNVPTXSharedMemoryAddrSpace = 3;
+ auto* tile_base_ptr = new llvm::GlobalVariable(
+ *ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
+ /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+ llvm::UndefValue::get(tile_type),
+ llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr,
+ llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
+ param_shmem_buffers[id] = tile_base_ptr;
+ VLOG(3) << "Added shmem buffer for parameter " << id << ": "
+ << llvm_ir::DumpToString(*tile_base_ptr);
+ }
+
+  // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
+  // for the purpose of tiling. Calculate the logical output dimensions measured
+  // in tiles from the reduced output dimensions.
+ std::vector<int64> output_dims_in_tiles = std::vector<int64>(
+ reduced_output_dims.begin(), reduced_output_dims.end());
+ CHECK_EQ(output_dims_in_tiles.size(), 3);
+ for (int i = 1; i < 3; ++i) {
+ output_dims_in_tiles[i] =
+ CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
+ }
+ const int64 num_tiles =
+ c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
+
+ llvm::Type* index_ty = GetIndexTypeForKernel(
+ hlo, launch_dimensions.launch_bound(), &ir_builder_);
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ // Cast each output IrArray to its corresponding reduced shape and keep the
+ // reduced shape live during IR emission.
+ std::vector<IrArray> output_in_reduced_shape_arrays;
+ std::vector<Shape> output_reduced_shapes;
+ CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
+ &output_in_reduced_shape_arrays),
+ num_outputs);
+
+ // For each tiled parameter, cast its input IrArray to the corresponding
+ // reduced shape and keep the reduced shape live during IR emission.
+ std::vector<IrArray> param_in_reduced_shape_arrays;
+ std::vector<Shape> param_reduced_shapes;
+ CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ *hlo, param_arrays, param_shmem_buffers, reduced_output_dims,
+ &param_reduced_shapes, &param_in_reduced_shape_arrays),
+ num_params);
+
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* x;
+ llvm::Value* y;
+ std::tie(y, x) = CalculateYXCoordinateWithinTile(
+ &ir_builder_, index_typed_constant(kTileSize), kThreadsPerTile);
+
+ // Calculate the index for the current output tile from block_id.
+ const IrArray::Index output_tile_index(
+ GetBlockIdx(&ir_builder_, index_ty, num_tiles),
+ ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
+ output_dims_in_tiles),
+ &ir_builder_);
+
+ // Output tile origin is the index for the first element of the current output
+ // tile.
+ const IrArray::Index output_tile_origin = [&] {
+ IrArray::Index index = output_tile_index;
+ for (int i = 1; i < 3; ++i) {
+ index[i] = ir_builder_.CreateMul(output_tile_index[i],
+ index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
+ }
+ return index;
+ }();
+
+ // Calculate the input tile origin from the output tile origin.
+ const IrArray::Index input_tile_origin(
+ Permute({0, 2, 1}, output_tile_origin.multidim()));
+
+ // Calculate the current output tile bounds in each of the logical dimensions.
+ std::vector<llvm::Value*> output_tile_bounds(3);
+ for (int i = 1; i < 3; ++i) {
+ // Only last row or column may not have full size.
+ output_tile_bounds[i] = ir_builder_.CreateSelect(
+ ir_builder_.CreateICmpEQ(
+ output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
+ }
+
+ KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+
+ // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
+ auto emit_tiled_elemental_code_with_bounds_check =
+ [&](const IrArray::Index& index, const string& loop_name,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ EmitTiledElementalCodeWithBoundsCheck(
+ kTileSize, kNumRows, index, loop_name, &ksl, &ir_builder_, y, x,
+ tile_width, tile_height, emit_elem_function);
+ };
+
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+ return index;
+ };
+ const IrArray::Index input_index =
+ offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+ // Copy input parameter values to shared memory buffers:
+ // tile[y, x] = input[index]
+ emit_tiled_elemental_code_with_bounds_check(
+ input_index, "input", output_tile_bounds[1], output_tile_bounds[2],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id];
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ // TODO(jlebar): Add AA metadata to this store. Tile buffers are
+          // global variables, so LLVM can't infer much about them.
+ ir_builder_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+ "input_element"),
+ ir_builder_.CreateGEP(shmem_buffer,
+ {index_typed_constant(0), y_loc, x}));
+ }
+ });
+
+ // Wait for all threads to reach this point, lest we copy a value from tile to
+  // output before another thread has copied it from input to tile.
+ // This is `__syncthreads` in CUDA.
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
+ &ir_builder_);
+
+ llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
+
+ const IrArray::Index output_index =
+ offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+ // Write to output[index] by emitting code like normal, except that values for
+ // the tiled parameters are read from the shmem buffers.
+ if (hlo->opcode() == HloOpcode::kCopy) {
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ // TODO(jlebar): Add AA metadata to this load.
+ llvm::Instruction* load_from_shmem_buffer = ir_builder_.CreateLoad(
+ ir_builder_.CreateGEP(param_shmem_buffers[0],
+ {ir_builder_.getInt64(0), x, y_loc}),
+ "output_element");
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, load_from_shmem_buffer, &ir_builder_);
+ });
+ } else {
+ CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+ tiled_param_info.set_y(y_loc);
+ fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+ TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+ IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+ index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+ &ir_builder_);
+ const llvm_ir::ElementGenerator& output_generator =
+ fused_emitter.GetRootGenerator();
+ llvm::Value* output_value =
+ output_generator(untiled_index).ValueOrDie();
+ if (hlo->IsMultiOutputFusion()) {
+ CHECK(output_value->getType()->isStructTy());
+ CHECK_EQ(output_value->getType()->getStructNumElements(),
+ output_in_reduced_shape_arrays.size());
+ for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+ output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+ index, ir_builder_.CreateExtractValue(output_value, i),
+ &ir_builder_);
+ }
+ } else {
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, output_value, &ir_builder_);
+ }
+ });
+ }
+
+ // For multioutput fusion, emit a tuple with all the individual outputs.
+ if (hlo->IsMultiOutputFusion()) {
+ std::vector<llvm::Value*> tuple_operand_ptrs;
+ for (int64 i = 0; i < output_arrays.size(); ++i) {
+ tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
+ }
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &ir_builder_,
+ module_);
+ }
+
+ return launch_dimensions;
+}
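
To make the tiling scheme above concrete, here is a much-simplified, single-threaded 2D analogue (an illustration under the same kTileSize = 32 constant; the scratch array stands in for shared memory, and on the GPU the two inner loops are distributed across the kTileSize x kNumRows threads of a block; this is not the code the emitter generates):

#include <vector>

// 2D analogue of the 0-2-1 tiled transpose: stage one tile of the input in a
// scratch buffer, then write it back out transposed.
void TransposeTiled(const std::vector<float>& in, std::vector<float>* out,
                    int rows, int cols) {
  constexpr int kTileSize = 32;
  // +1 on the minor dimension mirrors the padding above that avoids shared
  // memory bank conflicts on the GPU (it is harmless on the CPU).
  float tile[kTileSize][kTileSize + 1];
  for (int r0 = 0; r0 < rows; r0 += kTileSize) {
    for (int c0 = 0; c0 < cols; c0 += kTileSize) {
      // Copy phase: tile[y][x] = in[r0 + y][c0 + x], with bounds checks for
      // the partial tiles at the edges.
      for (int y = 0; y < kTileSize && r0 + y < rows; ++y)
        for (int x = 0; x < kTileSize && c0 + x < cols; ++x)
          tile[y][x] = in[(r0 + y) * cols + (c0 + x)];
      // A __syncthreads() barrier would separate the two phases on the GPU.
      // Write phase: out[c0 + y][r0 + x] = tile[x][y].
      for (int y = 0; y < kTileSize && c0 + y < cols; ++y)
        for (int x = 0; x < kTileSize && r0 + x < rows; ++x)
          (*out)[(c0 + y) * rows + (r0 + x)] = tile[x][y];
    }
  }
}
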
+
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+ HloOpcode opcode = hlo->opcode();
+ CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+ CHECK(opcode != HloOpcode::kFusion ||
+ hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
+ << "Only loop fusions are supported.";
+
+ const Shape& output_shape = hlo->IsMultiOutputFusion()
+ ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+ : hlo->shape();
+
+  // If the output_shape can be reduced to a 0-2-1 shape, find all the
+  // parameters of the hlo whose shapes are the corresponding 0-1-2 transpose.
+ std::vector<int64> params_012;
+ optional<std::vector<int64>> reduced_dims_021;
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ HloInstruction* operand = hlo->mutable_operand(operand_idx);
+ auto find_transpose_result =
+ llvm_ir::FindTranspose021(operand->shape(), output_shape);
+ if (!find_transpose_result.has_value()) {
+ continue;
+ }
+ const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+ if (!reduced_dims_021.has_value()) {
+ reduced_dims_021 = curr_reduced_dims_021;
+ }
+ if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ // There is more than one possible transpose. Instead of picking one
+ // transpose, we simply give up here.
+ return false;
+ }
+ params_012.push_back(operand_idx);
+ }
+
+ if (!reduced_dims_021.has_value()) {
+ return false;
+ }
+
+ if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
+ (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
+ return false;
+ }
+
+ VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
+ const LaunchDimensions launch_dimensions =
+ EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
+ UpdateLaunchDimensions(launch_dimensions, LastThunk(),
+ ir_emitter_context_->llvm_module());
+
+ return true;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index e8dce1ca53..59547c16d7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
namespace xla {
namespace gpu {
@@ -73,6 +74,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleInfeed(HloInstruction* xla_infeed) override;
+ Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
@@ -116,7 +118,7 @@ class IrEmitterUnnested : public IrEmitter {
// Emits code that reduces a matrix of shape [height x width] to a vector of
// [width]. Other parameters have the same meaning as those of
// `EmitReductionToVector`. Note that input shape might not be
- // [height x width], but can be bitcast to [height x weight] with "height"
+ // [height x width], but can be bitcast to [height x width] with "height"
// being the major dimension.
Status EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce,
@@ -132,7 +134,7 @@ class IrEmitterUnnested : public IrEmitter {
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
// vector of shape [height]. Other parameters have the same meaning as those
// of `EmitReductionToVector`. Note that input shape might not be
- // [depth x height x width], but can be bitcast to [depth x height x weight]
+ // [depth x height x width], but can be bitcast to [depth x height x width]
// with "depth" being the most major dimension.
Status EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
@@ -183,12 +185,56 @@ class IrEmitterUnnested : public IrEmitter {
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
+  // Checks whether the hlo instruction can be emitted with the 0-2-1 tiling
+  // algorithm; if so, emits the kernel that way and returns true.
+ bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
+ // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
+ // returns the launch dimensions for the kernel. This is a helper to support
+ // the implementation of CheckAndEmitHloWithTile021.
+ LaunchDimensions EmitHlo021Tile(
+ HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
+ // Generates the IrArray for each output of hlo and returns the number of
+ // outputs.
+ int ConstructIrArrayForOutputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* output_arrays);
+ // Generates the IrArray for each input of hlo and returns the number of
+ // inputs.
+ int ConstructIrArrayForInputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* param_arrays);
+ // For each output of the `hlo` instruction, constructs the reduced shape for
+  // the output with the given `reduced_output_dims` and casts the original
+ // output IrArray element in `output_arrays` to the reduced shape. Returns
+ // the number of outputs.
+ int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
+ // For each input of the `hlo` instruction, checks its value in
+ // `param_buffers` to find out whether the input has a reduced shape. If the
+ // input has a reduced shape, constructs the reduced shape for the input and
+ // casts the original input IrArray in `param_arrays` to the reduced shape.
+  // Returns the total number of inputs.
+ int ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
+
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
// Thunk object. The kernel implementation will be unrolled if unroll_factor
- // is greater than one.
- std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst,
- int unroll_factor = 1);
+ // is greater than one. 'implements_whole_instruction' specifies whether this
+ // KernelThunk implements the whole 'inst' HloInstruction. In some cases
+ // 'inst' will be implemented by a sequence of Thunks.
+ std::unique_ptr<KernelThunk> BuildKernelThunk(
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor = 1);
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
@@ -209,10 +255,14 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst);
- // Returns an InfeedThunk that performs device-to-device memcpy to implement
+ // Returns an InfeedThunk that performs a host-to-device memcpy to implement
// `inst`.
std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+ // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
+ // `inst`.
+ std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);
+
// Returns a WhileThunk that invokes thunk sequences for 'condition' and
// 'body' sub-computations of while instruction 'hlo'.
std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
new file mode 100644
index 0000000000..4aaf0c9e14
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
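+// The manager is intentionally leaked: keeping it behind a function-local
+// static pointer (rather than as a static object) avoids destruction-order
+// problems at process shutdown.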
+OutfeedManager* GetOrCreateOutfeedManager() {
+ static auto* manager = new OutfeedManager;
+ return manager;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
new file mode 100644
index 0000000000..a752eb7011
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+
+namespace xla {
+namespace gpu {
+
+// TODO(b/30467474) Once GPU outfeed implementation settles, consider
+// folding back the cpu and gpu outfeed implementations into a generic
+// one if possible.
+
+// Defines a buffer holding the destination for an outfeed in host memory and a
+// notification that triggers when the transfer is done.
+class OutfeedBuffer {
+ public:
+  explicit OutfeedBuffer(int64 length) : length_(length) {}
+
+ // Waits for the device transfer to be finished.
+ std::unique_ptr<Literal> WaitUntilAvailable() {
+ done_.WaitForNotification();
+ return std::move(destination_);
+ }
+
+ int64 length() const { return length_; }
+ void set_destination(std::unique_ptr<Literal> destination) {
+ destination_ = std::move(destination);
+ }
+ Literal* destination() { return destination_.get(); }
+
+ // Callback to signal that this buffer is consumed.
+ void Done() { done_.Notify(); }
+
+ private:
+ std::unique_ptr<Literal> destination_;
+ const int64 length_;
+ tensorflow::Notification done_;
+};
+
+// Manages a thread-safe queue of buffers. The buffers are supposed to be
+// produced by the transfer manager and consumed by the device.
+using OutfeedManager = XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*>;
+
+// Singleton creator-or-accessor: Returns the GPU outfeed manager.
+OutfeedManager* GetOrCreateOutfeedManager();
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
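
How the pieces above are meant to fit together, as a host-side sketch (a hypothetical helper assuming a non-tuple outfeed shape; the real producer is the GPU transfer manager, which is not part of this file):

// The host allocates the destination buffer, enqueues it, and blocks; the
// OutfeedThunk fills the destination literal and calls Done().
std::unique_ptr<xla::Literal> ReceiveScalarOutfeed(const xla::Shape& shape,
                                                   xla::int64 byte_length) {
  xla::ShapeTree<std::unique_ptr<xla::gpu::OutfeedBuffer>> buffers(shape);
  *buffers.mutable_element({}) =
      xla::MakeUnique<xla::gpu::OutfeedBuffer>(byte_length);
  xla::gpu::GetOrCreateOutfeedManager()->EnqueueDestination(&buffers);
  // Blocks until OutfeedThunk::ExecuteOnStream signals completion.
  return (*buffers.mutable_element({}))->WaitUntilAvailable();
}
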
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
new file mode 100644
index 0000000000..7986e63f43
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
+ const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kOutfeed, hlo_instruction),
+ outfeed_slices_(std::move(outfeed_slices)) {}
+
+Status OutfeedThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
+
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
+ OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
+ outfeed_manager->BlockingGetNextDestination();
+
+ // Nothing to be done for empty tuples.
+ if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) {
+ return Status::OK();
+ }
+ CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(),
+ outfeed_buffers->shape()));
+
+ TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
+ [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
+ if (!*buffer) { // Tuple pointers.
+ return Status::OK();
+ }
+ // Allocate storage for the literal data.
+ const Shape& shape =
+ ShapeUtil::GetSubshape(outfeed_buffers->shape(), index);
+ (*buffer)->set_destination(Literal::CreateFromShape(shape));
+
+ BufferAllocation::Slice slice = outfeed_slices_.element(index);
+ se::DeviceMemoryBase data_address;
+ if (slice.allocation()) {
+ // If we have a static allocation, read it from there. This avoids
+ // synchronizing the host and device just to read a pointer.
+ data_address = buffer_allocations.GetDeviceAddress(slice);
+ } else {
+ // Otherwise we have to read the tuple pointer first.
+ CHECK(!index.empty());
+ // Copy the parent buffer to the host.
+ BufferAllocation::Slice tuple_slice =
+ outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
+ if (!tuple_slice.allocation()) {
+ return Unimplemented(
+ "Nested dynamic tuples are not supported on GPU");
+ }
+ se::DeviceMemoryBase tuple_address =
+ buffer_allocations.GetDeviceAddress(tuple_slice);
+ CHECK(tuple_slice.size() % sizeof(void*) == 0)
+ << "Tuple size must be a multiple of pointer size";
+ std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
+ sizeof(void*));
+ stream->ThenMemcpy(tuple_element_buffer_addresses.data(),
+ tuple_address, tuple_slice.size());
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ // The data address is specified by the element of the tuple pointer
+ // buffer.
+ data_address =
+ se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
+ (*buffer)->length());
+ }
+
+ // TODO(b/111309141): Run this on a separate stream so it doesn't block
+ // the GPU from doing work during the transfer. This could be handled by
+ // making StreamAssignment do something intelligent with outfeed thunks.
+ stream
+ ->ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
+ (*buffer)->length())
+ .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
+ return Status::OK();
+ }));
+
+ Status block_status = stream->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ return InternalError("Failed to complete data transfer on stream %p: %s",
+ stream, block_status.error_message().c_str());
+ }
+
+ VLOG(2) << "Outfeeding from GPU complete";
+ return Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
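
The dynamic-tuple branch above depends on the on-device layout of a tuple buffer: a flat table with one pointer per immediate child, which is why tuple_slice.size() must divide evenly by sizeof(void*). A standalone sketch of the lookup once the table has been copied back (hypothetical helper, not code from this patch):

#include <cstddef>
#include <vector>

// Given the host copy of a device tuple's pointer table (as in
// tuple_element_buffer_addresses above), the child's device address is a
// plain indexed load; index.back() selects the child.
void* ChildDeviceAddress(const std::vector<void*>& tuple_table,
                         std::size_t child_index) {
  return tuple_table.at(child_index);
}
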
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
new file mode 100644
index 0000000000..8ed89f05f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A thunk that outfeeds data. Data must be already resident on the device.
+// This thunk performs a device-to-host copy from the buffer allocated for the
+// outfeed op to the host location.
+class OutfeedThunk : public Thunk {
+ public:
+  // Constructs an OutfeedThunk that copies data to the host-side
+ // outfeed queue from the buffers in the given shape tree.
+ OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
+ const HloInstruction* hlo_instruction);
+
+ OutfeedThunk(const OutfeedThunk&) = delete;
+ OutfeedThunk& operator=(const OutfeedThunk&) = delete;
+
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
+
+ private:
+ const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index c8f0d4185c..b22040eee1 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -68,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -234,9 +235,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(
+ LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index dfdba7d7d9..84285be70a 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -36,12 +36,7 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
Status SequentialThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- // TODO(b/71544591): We need to potentially measure the total time of the
- // sequential thunk. This happens for a reduce op which consists of
- // SequentialThunk with a thunk that initializes the output, and another thunk
- // that does the actual reduce. Right now, in this case we would only measure
- // the time of the last thunk, because both thunks would have the same
- // HloInstruction.
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (const auto& thunk : thunks_) {
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
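
The replacement works because MakeScopedInstructionProfiler returns an RAII object: the measurement starts at construction and stops when op_profiler leaves scope, so it now covers every thunk in the sequence. A generic sketch of that pattern (a hypothetical timer, not XLA's profiler API):

#include <chrono>
#include <cstdio>

// RAII scope timer: everything between construction and destruction is
// attributed to `label`, mirroring how the scoped instruction profiler
// brackets the loop above.
class ScopedTimer {
 public:
  explicit ScopedTimer(const char* label)
      : label_(label), start_(std::chrono::steady_clock::now()) {}
  ~ScopedTimer() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::printf("%s: %lld us\n", label_, static_cast<long long>(us));
  }

 private:
  const char* label_;
  std::chrono::steady_clock::time_point start_;
};
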
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 14d41033c2..99a1a0eae9 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -54,6 +54,7 @@ class Thunk {
kKernel,
kMemset32BitValue,
kMemzero,
+ kOutfeed,
kSequential,
kTuple,
kWhile,
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 5e13f989c2..1315a4183a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -30,10 +30,14 @@ WhileThunk::WhileThunk(
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
condition_result_buffer_index_(condition_result_buffer_index),
+ // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
+ // and body_thunk_sequence_ constructors because these SequentialThunks
+ // are logically "part of" this WhileThunk, and shouldn't be profiled
+ // separately from it.
condition_thunk_sequence_(MakeUnique<SequentialThunk>(
- std::move(*condition_thunk_sequence), hlo)),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ std::move(*condition_thunk_sequence), nullptr)),
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index 7749201cbc..c5321df6c4 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 2f290f61bd..dbc8442ed2 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -42,7 +42,7 @@ class WhileTransformerTest : public HloTestBase {
const int64 tuple_index, const int64 limit) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(limit)));
auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
0, GetLoopStateShape(tuple_index), "loop_state"));
auto induction_variable =
@@ -65,8 +65,8 @@ class WhileTransformerTest : public HloTestBase {
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, ind_var_tuple_index));
- auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(increment)));
+ auto inc = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(increment)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// Update data GTE(data_tuple_index).
@@ -89,10 +89,12 @@ class WhileTransformerTest : public HloTestBase {
const int64 ind_var_tuple_index,
const int64 ind_var_init) {
auto builder = HloComputation::Builder(TestName() + ".While");
- auto induction_var_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(ind_var_init)));
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto induction_var_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(ind_var_init)));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
auto loop_state_init =
ind_var_tuple_index == 0
? builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h
new file mode 100644
index 0000000000..737c7eb025
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace gpu {
+
+// TODO(b/30467474) Once GPU outfeed implementation settles, consider
+// folding back the cpu and gpu outfeed implementations into a generic
+// one if possible.
+
+// Manages a thread-safe queue of buffers.
+template <typename BufferType>
+class XfeedQueue {
+ public:
+ // Adds a tree of buffers to the queue. The individual buffers correspond to
+ // the elements of a tuple and may be nullptr if the buffer is a tuple index
+ // buffer.
+ void EnqueueDestination(BufferType buffers) {
+ tensorflow::mutex_lock l(mu_);
+ enqueued_buffers_.push_back(std::move(buffers));
+ cv_.notify_one();
+ }
+
+ // Blocks until the queue is non-empty, then returns the buffer at the head of
+ // the queue.
+ BufferType BlockingGetNextDestination() {
+ bool became_empty;
+ BufferType current_buffer;
+ {
+ tensorflow::mutex_lock l(mu_);
+ while (enqueued_buffers_.empty()) {
+ cv_.wait(l);
+ }
+ current_buffer = std::move(enqueued_buffers_.front());
+ enqueued_buffers_.pop_front();
+ became_empty = enqueued_buffers_.empty();
+ }
+ if (became_empty) {
+ for (const auto& callback : on_empty_callbacks_) {
+ callback();
+ }
+ }
+ return current_buffer;
+ }
+
+ void RegisterOnEmptyCallback(std::function<void()> callback) {
+ on_empty_callbacks_.push_back(std::move(callback));
+ }
+
+ private:
+ tensorflow::mutex mu_;
+
+ // Condition variable that is signaled every time a buffer is enqueued.
+ tensorflow::condition_variable cv_;
+
+ // The queue of trees of buffers. Buffer* queue contents are not owned.
+ std::deque<BufferType> enqueued_buffers_ GUARDED_BY(mu_);
+
+ // List of callbacks which will be called when 'enqueued_buffers_' becomes
+ // empty.
+ std::vector<std::function<void()>> on_empty_callbacks_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
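
A minimal usage sketch of the queue (hypothetical single producer and consumer; assumes the header above is included; not code from this patch):

#include <thread>

void Demo() {
  xla::gpu::XfeedQueue<int*> queue;
  int payload = 42;
  // The consumer blocks in BlockingGetNextDestination until a buffer arrives.
  std::thread consumer([&queue] {
    int* buffer = queue.BlockingGetNextDestination();
    (void)buffer;  // Process the dequeued buffer here.
  });
  queue.EnqueueDestination(&payload);
  consumer.join();
}
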
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index acf6611486..aa89567ee8 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -47,7 +48,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "x_value"));
auto half = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.5)));
builder.AddInstruction(HloInstruction::CreateBinary(
half->shape(), HloOpcode::kAdd, x_value, half));
return module->AddEmbeddedComputation(builder.Build());
@@ -122,7 +123,7 @@ std::unique_ptr<HloModule> MakeBigGraph() {
auto rng = builder.AddInstruction(
HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m}));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_computation = ScalarSumComputation(module.get());
builder.AddInstruction(
HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation));
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 3849b565e3..b41dc66fe9 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -239,7 +239,7 @@ class HeapSimulatorTest : public HloTestBase {
TEST_F(HeapSimulatorTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
// Constants aren't assigned. See b/32248867
HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
@@ -674,7 +674,7 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const BufferValue* DummyBufferValue() {
const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index a59bf1750c..403d4df6b5 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -116,9 +116,9 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) {
// Test the analysis on a single binary operation (Add).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, constant1, constant2));
module_->AddEntryComputation(builder.Build());
@@ -228,9 +228,9 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
module_->AddEntryComputation(builder.Build());
@@ -267,9 +267,9 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -346,15 +346,15 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -439,15 +439,15 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while0 = builder.AddInstruction(
@@ -498,7 +498,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
return cond_builder.Build();
};
// Build separate condition computations so the call graph is flat. The
@@ -543,9 +543,9 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto entry_while = builder.AddInstruction(
@@ -608,17 +608,17 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2, constant3}));
auto xla_while = builder.AddInstruction(
@@ -657,15 +657,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) {
// Test a kTupleSelect. Non-top-level elements flow through the instruction.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -753,16 +753,16 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -805,7 +805,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
// Bitcasting a value should not produce a new buffer.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
@@ -824,7 +824,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) {
// interference.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast}));
@@ -843,13 +843,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
// the other use of the init.
auto builder = HloComputation::Builder(TestName());
auto init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, init->shape(), "param"));
auto cond_root = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index e36bef60a3..166a83fade 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -528,8 +528,10 @@ HloInstruction* HloComputation::CreateFusionInstruction(
}
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
if (ShapeUtil::IsTuple(instruction->shape())) {
std::vector<HloInstruction*> elements;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
@@ -540,9 +542,8 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
instruction, i));
index->push_back(i);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element,
- DeepCopyHelper(gte, indices_to_copy, copies_added, index));
+ TF_ASSIGN_OR_RETURN(HloInstruction * element,
+ DeepCopyHelper(gte, index, copy_leaf));
elements.push_back(element);
index->pop_back();
}
@@ -556,19 +557,7 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
// Array shape.
TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape()));
- if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
- // Use kCopy to copy array elements
- HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- if (copies_added != nullptr) {
- *copies_added->mutable_element(*index) = copy;
- }
- return copy;
- } else {
- // Elements which are not to be copied are passed through
- // transparently.
- return instruction;
- }
+ return copy_leaf(instruction, *index, this);
}
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
@@ -590,7 +579,36 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
}
ShapeIndex index;
- return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index);
+ auto copy_leaf = [indices_to_copy, copies_added](
+ HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation) {
+ if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
+ HloInstruction* copy = computation->AddInstruction(
+ HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
+ if (copies_added != nullptr) {
+ *copies_added->mutable_element(leaf_index) = copy;
+ }
+ return copy;
+ }
+ // Elements which are not to be copied are passed through
+ // transparently.
+ return leaf;
+ };
+ return DeepCopyHelper(instruction, &index, copy_leaf);
+}
+
+StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
+ if (instruction->parent() != this) {
+ return FailedPrecondition(
+ "Can't deep copy instruction %s: instruction is not in computation %s",
+ instruction->name().c_str(), name().c_str());
+ }
+ ShapeIndex index;
+ return DeepCopyHelper(instruction, &index, copy_leaf);
}
ProgramShape HloComputation::ComputeProgramShape() const {
@@ -663,7 +681,7 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
- result->SetReachabilityToUnion(inputs, hlo);
+ result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index c1c3e79ebc..abc1da4da3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+#include <functional>
#include <list>
#include <memory>
#include <string>
@@ -254,6 +255,14 @@ class HloComputation {
const ShapeTree<bool>* indices_to_copy = nullptr,
ShapeTree<HloInstruction*>* copies_added = nullptr);
+ // As above, but uses a custom function to copy the leaf nodes, which could
+ // create alternative HLOs other than kCopy, or even pass-throughs.
+ StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
+
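A sketch of how a caller might use the new overload; the filtering predicate and the surrounding names here are hypothetical, not part of this change. The callback is invoked once per array-shaped leaf and may return either a freshly added instruction or the leaf itself as a pass-through:

    // Hypothetical copier: deep-copy only F32 leaves, pass the rest through.
    auto copy_leaf = [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation) -> HloInstruction* {
      if (leaf->shape().element_type() == F32) {  // assumed filter for illustration
        return computation->AddInstruction(HloInstruction::CreateUnary(
            leaf->shape(), HloOpcode::kCopy, leaf));
      }
      return leaf;  // non-matching leaves pass through transparently
    };
    TF_ASSIGN_OR_RETURN(HloInstruction * copy,
                        computation->DeepCopyInstructionWithCustomCopier(
                            instruction, copy_leaf));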
// Computes and returns the ProgramShape of this computation (shape of
// parameters and result with layout).
ProgramShape ComputeProgramShape() const;
@@ -378,8 +387,10 @@ class HloComputation {
// Internal helper for recursive copying of an instruction. Creates and
// returns a deep copy of the given instruction.
StatusOr<HloInstruction*> DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index a8f3f0e9c2..e4c5470331 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <set>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) {
// Test GetInstructionPostOrder for a computation with one instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant));
@@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) {
// instructions.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto negate2 = builder.AddInstruction(
@@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) {
// Test GetInstructionPostOrder for a computation with a trace instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto trace =
@@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(),
@@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
// computation has multiple roots (dead code).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
// Add three disconnected add expressions.
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
constant1, constant2));
@@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) {
// Test that DeepCopyInstruction properly copies an array.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
@@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) {
// Test that DeepCopyInstruction properly copies a tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
// copy are specified.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto computation = builder.Build();
{
@@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
// specified by the given indices.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto computation = builder.Build();
@@ -375,7 +375,7 @@ TEST_F(HloComputationTest, DeepCopyToken) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(token).ValueOrDie();
@@ -388,9 +388,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) {
// Test that DeepCopyInstruction properly handles tokens which should not be
// copied.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({token, constant}));
auto module = CreateNewModule();
@@ -407,7 +407,7 @@ TEST_F(HloComputationTest, CycleDetection) {
// Test whether the visitor can detect cycles in the graph.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto add = builder.AddInstruction(
@@ -433,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
// twice.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto dead_negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -456,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
TEST_F(HloComputationTest, CloneWithControlDependency) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
@@ -502,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) {
// There is a control dependency from 'add' to 'exp'.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto negate = builder.AddInstruction(
@@ -607,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationIndent) {
@@ -639,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) {
auto options =
HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
- EXPECT_EQ(computation->ToString(options),
- R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })");
+ })";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationCanonical) {
@@ -670,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation1 =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation1);
options = HloPrintOptions().Canonical();
- EXPECT_EQ(computation->ToString(options), R"(TransposeDot {
+ const string expected_computation2 = R"(TransposeDot {
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation2);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 436d103f23..7229031c0c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 5d05ccfc0b..64a42c1efc 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
@@ -62,7 +62,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@@ -82,8 +82,8 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
HloComputation::Builder builder(TestName());
- HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0f, 19.0f})));
+ HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
@@ -120,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
for (auto csize : test_config.concat_sizes) {
dimensions[test_config.concat_dimension] = csize;
concat_size += csize;
- auto literal = Literal::CreateFromDimensions(F32, dimensions);
+ auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
operands.push_back(insn);
@@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 9fc4c48226..9fd0363f57 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -338,13 +338,13 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
// tuple = Tuple({sub, sub, mul, C1})
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
auto c3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
@@ -391,9 +391,9 @@ TEST_F(FusionCostAnalysis, NoLayout) {
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 0fb65c845a..90d2be118d 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -261,9 +262,9 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(operand->shape().element_type()))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(
+ LiteralUtil::Zero(operand->shape().element_type()))));
return MakePadHlo(operand, zero, padding_config);
}
@@ -272,7 +273,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
ArraySlice<int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 7e7c4f95fe..60d3e71757 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -60,8 +60,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({3, 4}));
+ *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -82,10 +82,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module,
- {Literal::CreateR3<int32>(
+ {LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
CHECK_EQ(*result_literal,
- *Literal::CreateR2<int32>(
+ *LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,11 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9, 10}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(
+ *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -123,10 +124,11 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *Literal::CreateR3<int32>({{{9, 10}}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(
+ *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -145,8 +147,8 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9}}));
+ *module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -166,9 +168,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
+ *module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
CHECK_EQ(*result_literal,
- *Literal::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -188,8 +190,8 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -209,8 +211,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{0, 0}, {0, 0}}));
+ *module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -230,9 +232,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<float>(0.0f)}));
+ *module, {LiteralUtil::CreateR0<float>(0.0f)}));
CHECK_EQ(*result_literal,
- *Literal::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index a0ee889623..06484f4012 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -143,10 +143,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
if (instruction->operand_count() == 0) {
continue;
}
- // Skip instructions which have side effects or are a domain (which must
- // not be CSE-ed).
- if (instruction->HasSideEffect() ||
- instruction->opcode() == HloOpcode::kDomain) {
+ // Skip instructions which have side effects.
+ if (instruction->HasSideEffect()) {
continue;
}
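With the blanket kDomain skip gone, kDomain instructions become eligible for CSE like any other side-effect-free op; the Domain test added to hlo_cse_test.cc below exercises this, expecting the two structurally identical device=0-to-1 domain chains to be commoned while the device=0-to-2 chain stays distinct. One plausible shape of a domain-aware comparison, purely an assumption for illustration (hlo_domain_map.h is now included above, but the actual keying is outside this hunk):

    // Assumption: instructions should only be commoned within one domain.
    bool InSameDomain(const HloDomainMap& domain_map, HloInstruction* a,
                      HloInstruction* b) {
      return domain_map.GetDomainId(a) == domain_map.GetDomainId(b);
    }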
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 16db374566..76b9c66651 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) {
// Test that two identical constants are commoned.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR0<float>(84.0);
+ auto expected = LiteralUtil::CreateR0<float>(84.0);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
EXPECT_THAT(add, op::Add(first_operand, first_operand));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
auto builder = HloComputation::Builder(TestName());
std::vector<HloInstruction*> constants;
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
// Duplicate the float constant to verify that CSE merges at least one pair.
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
for (int64 i = 0; i < constants.size(); ++i) {
@@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) {
// Test that identical nonscalar constants are merged.
auto builder = HloComputation::Builder(TestName());
auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
// Create a constant which has the same shape but a different value.
auto uncommon_constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
// Tie the constants together with a tuple. This makes it easier to refer to
// the constant instructions via their use.
@@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
// Test that three identical instructions are commoned.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -253,7 +253,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
// commoned if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
@@ -284,7 +284,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
// the pass is layout insensitive.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
@@ -362,7 +362,7 @@ TEST_F(HloCseTest, IdenticalExpressions) {
// The *1 instructions should be merged with the *2 instructions.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kNegate, constant));
@@ -400,9 +400,9 @@ TEST_F(HloCseTest, DoNotCombineRng) {
// Test that two RNG ops are not commoned.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
{constant1, constant2}));
@@ -442,9 +442,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
auto builder = HloComputation::Builder(TestName() + "_rng_fun");
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto rng = builder.AddInstruction(HloInstruction::CreateRng(
scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2}));
auto param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -459,7 +459,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
{
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({5.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
auto rng1 = builder.AddInstruction(
HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
auto rng2 = builder.AddInstruction(
@@ -521,9 +521,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
// in this case) are not collapsed.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -536,5 +536,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
}
+TEST_F(HloCseTest, Domain) {
+ auto module = ParseHloString(R"(
+HloModule module
+ENTRY %entry {
+ %param = f32[] parameter(0), sharding={maximal device=0}
+ %domain.0 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.1 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.2 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
+ %negate.0 = f32[] negate(%domain.0)
+ %negate.1 = f32[] negate(%domain.1)
+ %negate.2 = f32[] negate(%domain.2)
+ %domain.3 = f32[] domain(%negate.0),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.4 = f32[] domain(%negate.1),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.5 = f32[] domain(%negate.2),
+ domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
+ %add = f32[] add(%domain.3, %domain.4)
+ ROOT %sub = f32[] subtract(%add, %domain.5)
+})")
+ .ValueOrDie();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ const HloInstruction* sub = module->entry_computation()->root_instruction();
+ const HloInstruction* add = sub->operand(0);
+ EXPECT_EQ(add->operand(0), add->operand(1));
+ EXPECT_NE(add->operand(0), sub->operand(1));
+ EXPECT_NE(add->operand(1), sub->operand(1));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index f176473366..37bc2d2c9d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -101,9 +101,9 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
// Test the dataflow for a simple binary operation (Add).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, constant1, constant2));
module_->AddEntryComputation(builder.Build());
@@ -198,9 +198,9 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
// Verify the dataflow through a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto nested_tuple = builder.AddInstruction(
@@ -259,9 +259,9 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
module_->AddEntryComputation(builder.Build());
@@ -308,9 +308,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -362,9 +362,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -426,9 +426,9 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, outer_computation));
module_->AddEntryComputation(builder.Build());
@@ -493,15 +493,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -594,15 +594,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while0 = builder.AddInstruction(
@@ -653,7 +653,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
@@ -691,9 +691,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto entry_while = builder.AddInstruction(
@@ -780,15 +780,15 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -840,11 +840,11 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) {
// Test a kSelect of an array value.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2));
@@ -863,15 +863,15 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
// Test a kTupleSelect. Non-top-level element flow through the instruction.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -939,17 +939,17 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
// Test kTupleSelect of a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto constant5 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0)));
auto inner_tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant2, constant3}));
auto tuple1 = builder.AddInstruction(
@@ -1025,18 +1025,18 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -1088,7 +1088,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
// Test the bitcast_defines_value flag to the dataflow analysis.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
@@ -1157,7 +1157,7 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto send = builder.AddInstruction(
HloInstruction::CreateSend(param, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
@@ -1182,7 +1182,7 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
// Test that a RecvDone forwards its operand tuple element at {0} to element
// {0} of the output.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
@@ -1309,13 +1309,13 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
auto constant = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto exp = body_builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, exp, body_param));
auto dead_constant = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, dead_constant));
HloComputation* body = module_->AddEmbeddedComputation(
@@ -1325,7 +1325,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
@@ -1576,11 +1576,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
scalar_shape_, pred, constant1, true_computation, constant2,
false_computation));
@@ -1667,11 +1667,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
auto tuple_operand = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
@@ -1797,15 +1797,15 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
// Build entry computation.
auto builder = HloComputation::Builder(TestName());
auto pred1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto pred2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
auto tuple_operand = builder.AddInstruction(
HloInstruction::CreateTuple({pred2, constant1, constant2}));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
@@ -1943,9 +1943,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -2048,7 +2048,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -2076,7 +2076,7 @@ TEST_F(CanShareOperandBufferWithUserTest,
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "param0"));
auto index = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 0})));
auto ds = builder.AddInstruction(
HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2}));
@@ -2144,9 +2144,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -2184,9 +2184,9 @@ TEST_F(CanShareOperandBufferWithUserTest,
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape_bf16, convert1, update, starts));
@@ -2237,9 +2237,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
@@ -2248,7 +2248,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -2270,7 +2270,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -2278,7 +2278,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
@@ -2298,13 +2298,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kMultiply, operand, operand));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));
@@ -2370,7 +2370,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
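
Most of the churn in this file is mechanical: the literal factory methods moved
from the Literal class to LiteralUtil (with the header split into literal.h and
literal_util.h), and the empty-operand CreateAfterAll({}) idiom for making a
token was replaced by the dedicated CreateToken() factory. A minimal sketch of
the new spellings, using only calls that appear in this patch (builder here is
an HloComputation::Builder, as in the tests above):

    #include "tensorflow/compiler/xla/literal_util.h"

    // Literal construction now goes through LiteralUtil.
    auto scalar = LiteralUtil::CreateR0<float>(1.0);   // was Literal::CreateR0
    auto vector = LiteralUtil::CreateR1<int32>({2});   // was Literal::CreateR1
    // Token creation has a dedicated factory.
    auto token = builder.AddInstruction(HloInstruction::CreateToken());
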
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index f5524dc6fe..26e3736e01 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -53,9 +53,9 @@ TEST_F(HloDceTest, NoDeadCode) {
// Verify that no dead code is removed from a computation with no dead code.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -74,8 +74,8 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) {
// Verify that side-effect instructions (Send in this test) are not removed.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
builder.AddInstruction(HloInstruction::CreateTuple({}));
@@ -127,9 +127,9 @@ TEST_F(HloDceTest, ControlDependencies) {
// Verify that instructions with control dependencies are not removed.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
// Create two dead instructions: a negate and an add.
auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -224,7 +224,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
auto constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant));
}
@@ -235,8 +235,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
{
auto param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
- auto token =
- body_builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
auto infeed = body_builder.AddInstruction(
HloInstruction::CreateInfeed(shape, token, ""));
body_builder.AddInstruction(
@@ -280,8 +279,8 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
{
auto param = nested_callee_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
- auto token = nested_callee_builder.AddInstruction(
- HloInstruction::CreateAfterAll({}));
+ auto token =
+ nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
nested_callee_builder.AddInstruction(
HloInstruction::CreateOutfeed(shape, param, token, ""));
}
@@ -346,12 +345,12 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
// Add another instruction as the root of the computation.
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
module->AddEntryComputation(builder.Build());
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
@@ -387,7 +386,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
// Add another instruction as the root of the computation that also uses
@@ -397,7 +396,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index ebd5adb5d5..9e096320db 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -41,11 +41,15 @@ namespace xla {
bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
HloInstruction* instruction2) const {
- int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1);
- int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1);
+ int64 domain_id1 = GetDomainId(instruction1);
+ int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
+int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+ return FindOrDefault(instruction_to_domain_, instruction, -1);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ domain->enter_domains.insert(instruction);
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index e62ef763fb..1ca7159725 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -65,6 +65,10 @@ class HloDomainMap {
// currently processing.
bool IsDomainInstruction(HloInstruction* instruction) const;
+ // Retrieves the domain identifier of the instruction, or -1 if the
+ // instruction is not found within any domain.
+ int64 GetDomainId(HloInstruction* instruction) const;
+
private:
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
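
GetDomainId exposes the lookup that InSameDomain was already doing internally,
with -1 reserved for instructions that are not assigned to any domain. A
minimal caller-side sketch, assuming a domain_map built with
HloDomainMap::Create() as elsewhere in this patch:

    int64 id1 = domain_map->GetDomainId(instruction1);
    int64 id2 = domain_map->GetDomainId(instruction2);
    if (id1 < 0) {
      // instruction1 is not tracked by any domain of this kind.
    } else if (id1 == id2) {
      // Same domain; this is exactly what InSameDomain() now computes.
    }
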
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
index 1d06040b0e..e2e820002b 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -43,46 +43,8 @@ class HloDomainRemover::RunContext {
Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain(
const DomainMetadata::Domain& domain) {
- // Verify that the whole kDomain frontier bounding the instruction reach set,
- // has matching metadata.
- // A kDomain instruction has two sides of metadata, a user facing and an
- // operand facing.
- // A reachable instruction set can make contact with a kDomain instruction on
- // a user facing side (the kDomain is operand of the instruction), or on a
- // operand facing side (the kDomain is user of the instruction).
- // And depending on the contact side, the proper metadata object
- // (user_side_metadata() vs. operand_side_metadata()) needs to be used for
- // consistency checks.
- const DomainMetadata* ref_metadata = nullptr;
- VLOG(4) << "Reach set:";
- for (HloInstruction* instruction : domain.instructions) {
- VLOG(4) << " " << instruction->name();
- }
- VLOG(4) << " Domains:";
- for (HloInstruction* instruction : domain.enter_domains) {
- const DomainMetadata& meta = instruction->user_side_metadata();
- VLOG(4) << " User side: " << instruction->name();
- VLOG(4) << " " << meta.ToString();
- if (ref_metadata == nullptr) {
- ref_metadata = &meta;
- } else {
- TF_RET_CHECK(meta.Matches(*ref_metadata))
- << "Metadata mismatch at instruction " << instruction->name() << " : "
- << meta.ToString() << " vs " << ref_metadata->ToString();
- }
- }
- for (HloInstruction* instruction : domain.exit_domains) {
- const DomainMetadata& meta = instruction->operand_side_metadata();
- VLOG(4) << " Operand side: " << instruction->name();
- VLOG(4) << " " << meta.ToString();
- if (ref_metadata == nullptr) {
- ref_metadata = &meta;
- } else {
- TF_RET_CHECK(meta.Matches(*ref_metadata))
- << "Metadata mismatch at instruction " << instruction->name() << " : "
- << meta.ToString() << " vs " << ref_metadata->ToString();
- }
- }
+ TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata,
+ HloDomainVerifier::VerifyDomain(domain));
if (ref_metadata != nullptr) {
VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString();
TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain));
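
The remover's inline frontier walk now lives in HloDomainVerifier::VerifyDomain,
which returns the common DomainMetadata (or nullptr when the domain has no
kDomain boundary) wrapped in a StatusOr. TF_ASSIGN_OR_RETURN unpacks that
result; conceptually, not literally, the macro expands to a sketch like:

    // Rough expansion of the TF_ASSIGN_OR_RETURN call above (illustrative).
    auto statusor = HloDomainVerifier::VerifyDomain(domain);
    if (!statusor.ok()) {
      return statusor.status();  // Propagate the verification failure.
    }
    const DomainMetadata* ref_metadata = statusor.ValueOrDie();
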
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 3859e4cae6..00b2c860a7 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -436,6 +436,44 @@ ENTRY entry {
HloSharding::AssignDevice(0)}));
}
+TEST_F(HloDomainTest, EmptyRootDomain) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ %param = f32[1] parameter(0), sharding={maximal device=0}
+ %tuple = (f32[1]) tuple(%param),
+ sharding={maximal device=1}
+ ROOT %gte = f32[1] get-tuple-element(%tuple), index=0,
+ sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "tuple", "param"));
+ EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple"));
+
+ // Remove %tuple and %gte (tuple simplification)
+ HloInstruction* gte = FindInstruction(module, "gte");
+ HloInstruction* tuple = FindInstruction(module, "tuple");
+ module->entry_computation()->set_root_instruction(tuple->mutable_operand(0));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple));
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
+ EXPECT_TRUE(remover_changed);
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(root->has_sharding());
+ EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1));
+}
+
// Tests that text dumps of domain instructions can be parsed back, in the
// specific case of null shardings.
TEST_F(HloDomainTest, DumpParseNullSharding) {
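
Since the verifier introduced by this patch is a standalone pass, a test like
EmptyRootDomain above could also sanity-check the module between the isolator
and remover runs. An illustrative sketch, not part of the patch (an empty kinds
vector makes the verifier collect every domain kind it finds in the module):

    HloDomainVerifier verifier({});
    TF_EXPECT_OK(verifier.Run(module).status());
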
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
new file mode 100644
index 0000000000..751fc677e2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h"
+
+#include <set>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+class HloDomainVerifier::RunContext {
+ public:
+ RunContext(HloModule* module, HloDomainVerifier* verifier)
+ : module_(module), verifier_(verifier) {}
+
+ Status Run();
+
+ private:
+ // If the verifier's caller passed an empty vector for kinds, we collect all
+ // the available domain kinds.
+ Status PopulateDomainKinds();
+
+ HloModule* module_;
+ HloDomainVerifier* verifier_;
+};
+
+Status HloDomainVerifier::RunContext::PopulateDomainKinds() {
+ if (verifier_->kinds_.empty()) {
+ // The caller specified no domain kinds, collect all the ones available.
+ std::set<string> kinds;
+ for (HloComputation* computation : module_->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kDomain) {
+ TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
+ instruction->operand_side_metadata().Kind())
+ << instruction->ToString();
+ kinds.insert(instruction->user_side_metadata().Kind().ToString());
+ }
+ }
+ }
+ verifier_->kinds_.insert(verifier_->kinds_.end(), kinds.begin(),
+ kinds.end());
+ }
+ return Status::OK();
+}
+
+Status HloDomainVerifier::RunContext::Run() {
+ VLOG(4) << "Running HLO Domain Verifier";
+ TF_RETURN_IF_ERROR(PopulateDomainKinds());
+ for (HloComputation* computation : module_->computations()) {
+ for (auto& kind : verifier_->kinds_) {
+ // First create the domain instruction sets. A domain instruction set is
+ // the set of instructions whose edges never cross a kDomain instruction.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDomainMap> domain_map,
+ HloDomainMap::Create(computation, kind));
+ // Verify every domain populated within the map.
+ for (auto& domain : domain_map->GetDomains()) {
+ TF_RETURN_IF_ERROR(VerifyDomain(*domain).status());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<bool> HloDomainVerifier::Run(HloModule* module) {
+ RunContext run_context(module, this);
+ TF_RETURN_IF_ERROR(run_context.Run());
+ return false;
+}
+
+StatusOr<const DomainMetadata*> HloDomainVerifier::VerifyDomain(
+ const DomainMetadata::Domain& domain) {
+ const DomainMetadata* ref_metadata = nullptr;
+ VLOG(4) << "Reach set:";
+ for (HloInstruction* instruction : domain.instructions) {
+ VLOG(4) << " " << instruction->name();
+ }
+ VLOG(4) << " Domains:";
+ for (HloInstruction* instruction : domain.enter_domains) {
+ const DomainMetadata& meta = instruction->user_side_metadata();
+ VLOG(4) << " User side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ for (HloInstruction* instruction : domain.exit_domains) {
+ const DomainMetadata& meta = instruction->operand_side_metadata();
+ VLOG(4) << " Operand side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ return ref_metadata;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
new file mode 100644
index 0000000000..8e53cf97f8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// Verifies that the domain instructions are consistent, and that each domain
+// is surrounded by the same metadata.
+class HloDomainVerifier : public HloPassInterface {
+ public:
+ HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
+
+ tensorflow::StringPiece name() const override { return "domain_verifier"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ // Verifies that the whole kDomain frontier bounding the instruction reach
+ // set has matching metadata.
+ // A kDomain instruction has two sides of metadata: a user-facing side and
+ // an operand-facing side.
+ // A reachable instruction set can make contact with a kDomain instruction
+ // on the user-facing side (the kDomain is an operand of the instruction) or
+ // on the operand-facing side (the kDomain is a user of the instruction).
+ // Depending on the contact side, the proper metadata object
+ // (user_side_metadata() vs. operand_side_metadata()) needs to be used for
+ // consistency checks.
+ // Returns the DomainMetadata pointer which surrounds the domain, and
+ // represents the common metadata within such domain. If the returned
+ // DomainMetadata pointer is nullptr, the input domain had no kDomain
+ // boundary.
+ static StatusOr<const DomainMetadata*> VerifyDomain(
+ const DomainMetadata::Domain& domain);
+
+ private:
+ class RunContext;
+
+ std::vector<string> kinds_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
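
As an HloPassInterface pass, the verifier slots into a pipeline like any other
pass; its Run() only checks and always reports false for "changed". A minimal
usage sketch, with the "sharding" kind string taken from the tests in this
patch and module assumed to be an HloModule*:

    // Verify only sharding domains.
    HloDomainVerifier sharding_verifier({"sharding"});
    TF_ASSIGN_OR_RETURN(bool changed, sharding_verifier.Run(module));
    CHECK(!changed);  // The verifier never rewrites the module.
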
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index 4ed1508d70..c804f4364f 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 47da46bfad..dfdfeb49a2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -135,7 +136,6 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
} // namespace
-
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
@@ -330,6 +330,24 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.CloneToUnique());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+ TF_ASSIGN_OR_RETURN(
+ Shape dot_shape,
+ ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
+ dim_numbers);
+ return Evaluate(cloned_instruction.get());
+}
+
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
@@ -382,7 +400,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
- auto result_literal = Literal::CreateFromDimensions(
+ auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
@@ -533,7 +551,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
}
- evaluated_[tuple] = Literal::MakeTuple(operand_literals);
+ evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
return Status::OK();
}
@@ -757,6 +775,12 @@ class OutputWindowIndexToInputIndex {
return ArraySlice<int64>(input_index_);
}
+ // Returns, for a given 'input_dim', the corresponding output dimension
+ // index, or -1 if 'input_dim' is an elided window dimension.
+ int64 input_dim_value_to_output_index(int64 input_dim) {
+ return input_dim_value_to_output_index_[input_dim];
+ }
+
private:
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
@@ -774,7 +798,7 @@ class OutputWindowIndexToInputIndex {
// input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
// the input index from the output index. See
- // PropagateOutputIndexToInputIndex.
+ // PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -827,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
+ std::vector<int64> input_gather_index_clamped(
+ operand.shape().dimensions_size());
OutputGatherIndexToInputIndex output_gather_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
@@ -848,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
+ for (int i = 0, e = input_gather_index.size(); i < e; i++) {
+ int64 output_dim =
+ output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
+ // we set the iteration index to 0, so for the purpose of the following
+ // calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : shape.dimensions(output_dim);
+ // Clamp the gather index so that the gather region fits in the operand.
+ // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // operand_shape.dimensions(i) -
+ // output_dim_size);
+ input_gather_index_clamped[i] =
+ std::min(operand_shape.dimensions(i) - output_dim_size,
+ std::max(0LL, input_gather_index[i]));
+ }
for (int i = 0, e = input_index.size(); i < e; i++) {
- // TODO(b/74360564): We should implement whatever out of bounds behavior
- // we decide for dynamic-slice here as well.
- input_index[i] = (input_gather_index[i] + input_window_index[i]) %
- operand_shape.dimensions(i);
- if (input_index[i] < 0) {
- input_index[i] += operand_shape.dimensions(i);
- }
+ input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ DCHECK_GE(input_index[i], 0);
+ DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
result->CopyElementFrom(operand, input_index, output_index));
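
The old behavior wrapped out-of-bounds gather start indices with a modulus; the
new code instead clamps each start index so the whole gather window stays inside
the operand, i.e. the start is clamped to [0, dim_size - window_size]. A worked
sketch of the clamp in isolation, with hypothetical sizes not taken from the
patch:

    // Operand dimension of size 6, gather window of size 3 along it.
    int64 dim_size = 6;
    int64 window_size = 3;
    auto clamp_start = [&](int64 start) {
      return std::min(dim_size - window_size, std::max(int64{0}, start));
    };
    // clamp_start(-2) == 0  -> window covers [0, 3)
    // clamp_start(2)  == 2  -> window covers [2, 5)
    // clamp_start(5)  == 3  -> window shifted back to fit, covers [3, 6)
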
@@ -903,7 +941,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
}
Status HloEvaluator::HandleAfterAll(HloInstruction* token) {
- evaluated_[token] = Literal::CreateToken();
+ evaluated_[token] = LiteralUtil::CreateToken();
return Status::OK();
}
@@ -1119,7 +1157,7 @@ std::unique_ptr<Literal> EvaluateSortInternal(HloInstruction* sort,
auto result_values_literal = MakeUnique<Literal>(sort->operand(1)->shape());
result_values_literal->PopulateR1(
tensorflow::gtl::ArraySlice<ValueType>(result_values));
- auto result_tuple = Literal::MakeTuple(
+ auto result_tuple = LiteralUtil::MakeTuple(
{result_keys_literal.get(), result_values_literal.get()});
VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
return result_tuple;
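
EvaluateDotOp lets callers constant-fold a dot product without building a
module: it wraps both literals in temporary constant instructions, infers the
result shape, and evaluates the synthetic kDot. A minimal usage sketch; the
dimension-number setup mirrors the FusedDotAdd test earlier in this patch:

    HloEvaluator evaluator(/*max_loop_iterations=*/-1);
    auto lhs = LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}});
    auto rhs = LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}});
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
                        evaluator.EvaluateDotOp(dot_dnums, *lhs, *rhs));
    // lhs is the identity, so result holds {{2, 2}, {2, 2}}.
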
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 2850c5cb1a..a4c37ef328 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand);
+ StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs);
+
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
// class.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 42770d848a..5f575b24a1 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
@@ -112,9 +112,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
// with 3 operands.
TEST_P(HloEvaluatorTest, DoesClamp) {
- auto low = Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
- auto value = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
+ auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
+ auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
+ auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
Shape shape = low->shape();
HloComputation::Builder b(TestName());
@@ -127,15 +127,15 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
- auto low = Literal::CreateR0<float>(0.f);
- auto value = Literal::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
- auto high = Literal::CreateR0<float>(1.f);
+ auto low = LiteralUtil::CreateR0<float>(0.f);
+ auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
+ auto high = LiteralUtil::CreateR0<float>(1.f);
Shape shape = value->shape();
HloComputation::Builder b(TestName());
@@ -148,7 +148,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -156,9 +156,9 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
// with 3 operands.
TEST_P(HloEvaluatorTest, DoesSelect) {
- auto pred = Literal::CreateR2<bool>({{true, false}, {false, true}});
- auto on_true = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- auto on_false = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
+ auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
+ auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
+ auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
Shape shape = on_true->shape();
HloComputation::Builder b(TestName());
@@ -173,7 +173,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
std::unique_ptr<Literal> result = Evaluate({});
- auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -181,46 +181,46 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise addition with 2 operands.
TEST_P(HloEvaluatorTest, DoesAdd) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{3, 4}, {-96, 8}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise and with 2 operands.
TEST_P(HloEvaluatorTest, DoesAnd) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{0, 0}, {4, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise or with 2 operands.
TEST_P(HloEvaluatorTest, DoesOr) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{3, 4}, {-100, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise or with 2 operands.
TEST_P(HloEvaluatorTest, DoesXor) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{3, 4}, {-104, 0}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise multiply with 2 operands.
TEST_P(HloEvaluatorTest, DoesMultiply) {
- auto lhs = Literal::CreateR2<int32>({{-1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int32>(
+ auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 4}, {4, 4}});
- auto expected = Literal::CreateR2<int32>(
+ auto expected = LiteralUtil::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
std::move(rhs));
@@ -228,17 +228,17 @@ TEST_P(HloEvaluatorTest, DoesMultiply) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise divide with 2 operands.
TEST_P(HloEvaluatorTest, DoesDivideInt64) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{0, 0}, {-25, 1}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
TEST_P(HloEvaluatorTest, DoesDivideDouble) {
- auto lhs = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
- auto rhs = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
+ auto lhs = LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
+ auto rhs = LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
auto expected =
- Literal::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
+ LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
@@ -246,54 +246,54 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise abs op with 1 operand.
TEST_P(HloEvaluatorTest, DoesAbsR2) {
- auto operand = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
- auto expected = Literal::CreateR2<int64>({{1, 20}, {100, 4}});
+ auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesAbsR0) {
- auto operand = Literal::CreateR0<float>(-1.0f);
- auto expected = Literal::CreateR0<float>(1.0f);
+ auto operand = LiteralUtil::CreateR0<float>(-1.0f);
+ auto expected = LiteralUtil::CreateR0<float>(1.0f);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) {
- auto operand = Literal::CreateR1<float>({});
- auto expected = Literal::CreateR1<float>({});
+ auto operand = LiteralUtil::CreateR1<float>({});
+ auto expected = LiteralUtil::CreateR1<float>({});
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesNegateR2) {
- auto operand = Literal::CreateR2<int32>(
+ auto operand = LiteralUtil::CreateR2<int32>(
{{0, std::numeric_limits<int32>::min()}, {-1, 4}});
- auto expected =
- Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()}, {1, -4}});
+ auto expected = LiteralUtil::CreateR2<int32>(
+ {{0, std::numeric_limits<int>::min()}, {1, -4}});
TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesCosR2) {
- auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
- auto expected = Literal::CreateR2<float>({{1, -1}, {-1, 1}});
+ auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
+ auto expected = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand),
use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesSinR2) {
- auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
- auto expected = Literal::CreateR2<float>({{0, 0}, {0, 0}});
+ auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesNotR2) {
auto operand =
- Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
- {-1, std::numeric_limits<int>::max()}});
+ LiteralUtil::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
+ {-1, std::numeric_limits<int>::max()}});
auto expected =
- Literal::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
- {0, std::numeric_limits<int>::min()}});
+ LiteralUtil::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
+ {0, std::numeric_limits<int>::min()}});
TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
// Verifies that HloEvaluator evaluates an HLO Computation whose operands are
// neither parameters nor constants.
TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto rhs2 = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -314,7 +314,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
std::unique_ptr<Literal> result = Evaluate(args);
- auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
+ auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -324,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction =
@@ -349,8 +349,8 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
// Verifies Broadcast operation is correctly evaluated.
TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
- auto output_literal = Literal::CreateR3<int32>(
+ auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
+ auto output_literal = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -365,8 +365,8 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR0<int32>(111);
- auto output_literal = Literal::CreateR2<int32>(
+ auto input_literal = LiteralUtil::CreateR0<int32>(111);
+ auto output_literal = LiteralUtil::CreateR2<int32>(
{{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});
HloInstruction* literal_instruction = b.AddInstruction(
@@ -386,9 +386,9 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
HloComputation::Builder b(TestName());
HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int64>({{-1, -2}, {100, 200}})));
+ LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int64>({{-2, -3}, {-100, -200}})));
+ LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
std::vector<HloInstruction*> operands = {operand1, operand2};
@@ -399,8 +399,8 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected =
- Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
+ auto expected = LiteralUtil::CreateR2<int64>(
+ {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -408,9 +408,9 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
HloComputation::Builder b(TestName());
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({100, 200})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
HloInstruction* operand2 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
std::vector<HloInstruction*> operands = {operand1, operand2};
@@ -421,16 +421,16 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<int64>({100, 200});
+ auto expected = LiteralUtil::CreateR1<int64>({100, 200});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
+ auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
@@ -447,9 +447,9 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2WithLayout<int32>(
+ auto input_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
- auto expected = Literal::CreateR2WithLayout<float>(
+ auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
@@ -478,13 +478,13 @@ PaddingConfig CreatePaddingConfig(
}
TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
- auto operand = Literal::CreateR2<int32>({{}, {}});
+ auto operand = LiteralUtil::CreateR2<int32>({{}, {}});
HloComputation::Builder b(TestName());
auto operand_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
constexpr int32 kPadValue = 10;
- auto pad_value = Literal::CreateR0<int32>(kPadValue);
+ auto pad_value = LiteralUtil::CreateR0<int32>(kPadValue);
auto padding_value_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
@@ -496,7 +496,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<int32>(
+ auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -506,11 +506,11 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
HloComputation::Builder b(TestName());
Array4D<float> input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
constexpr float kPadValue = 1.5;
- auto pad_value = Literal::CreateR0<float>(kPadValue);
+ auto pad_value = LiteralUtil::CreateR0<float>(kPadValue);
HloInstruction* pad_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
@@ -532,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
(*expected_array)(7, 0, 0, 0) = 5.0f;
(*expected_array)(7, 2, 0, 0) = 6.0f;
- auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -549,12 +549,12 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
// }
auto input_array = MakeUnique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
- auto input = Literal::CreateR2FromArray2D<float>(*input_array);
+ auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
auto pad_value_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.718f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
auto r2_padding_on_dim0_dim1 =
CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
@@ -574,7 +574,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 2) = 2.718f;
(*expected_array)(0, 3) = 2.718f;
(*expected_array)(0, 4) = 2.718f;
- auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
}
@@ -590,12 +590,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
// }
auto input_array = MakeUnique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
- auto input = Literal::CreateR2FromArray2D<float>(*input_array);
+ auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
auto pad_value_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.718f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -613,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
std::unique_ptr<Literal> result = Evaluate();
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
- auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -630,13 +630,13 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// }
auto lhs_array = MakeUnique<Array2D<float>>(4, 1);
lhs_array->FillUnique(1.0f);
- auto lhs_literal = Literal::CreateR2FromArray2D<float>(*lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
// rhs:
// f32[1,2] {{ 1, 2 }},
- auto rhs_literal = Literal::CreateR2<float>({{1, 2}});
+ auto rhs_literal = LiteralUtil::CreateR2<float>({{1, 2}});
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -658,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
{4.f, 8.f},
});
// clang-format on
- auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -669,7 +669,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// lhs:
// f32[3]
// { 1, 2, 3 },
- auto lhs_literal = Literal::CreateR1<float>({1, 2, 3});
+ auto lhs_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -681,7 +681,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// }
auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
- auto rhs_literal = Literal::CreateR2FromArray2D<float>(*rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -695,7 +695,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<float>({22.f, 28.f});
+ auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -712,7 +712,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// }
auto lhs_array = MakeUnique<Array2D<float>>(4, 3);
lhs_array->FillUnique(1.0f);
- auto lhs_literal = Literal::CreateR2FromArray2D<float>(*lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -724,7 +724,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// }
auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
- auto rhs_literal = Literal::CreateR2FromArray2D<float>(*rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -744,7 +744,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
{94.f, 124.f},
{130.f, 172.f},
});
- auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -753,12 +753,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
HloComputation::Builder b(TestName());
Array3D<float> lhs_array = {{{1, 2, 3}}};
- auto lhs_literal = Literal::CreateR3FromArray3D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
Array3D<float> rhs_array = {{{3.f, 4.f}}};
- auto rhs_literal = Literal::CreateR3FromArray3D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -792,7 +792,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
std::unique_ptr<Literal> result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
- auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -809,7 +809,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -820,7 +820,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -854,7 +854,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{149, 160, 171, 80},
}));
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -884,11 +884,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
}});
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
@@ -933,7 +933,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
Array4D<float> expected_array({{{{2514, 2685}}}});
Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -964,11 +964,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
}});
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1010,7 +1010,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
Array4D<float> expected_array({{{{2514, 2685}}}});
Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -1028,7 +1028,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1039,7 +1039,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1074,7 +1074,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{91, 112, 98, 120, 105, 128, 112},
{65, 84, 70, 90, 75, 96, 80},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1091,7 +1091,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1102,7 +1102,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1138,7 +1138,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{104, 91, 112, 98, 120, 105, 128, 112},
{78, 65, 84, 70, 90, 75, 96, 80},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1156,7 +1156,7 @@ TEST_P(HloEvaluatorTest,
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1167,7 +1167,7 @@ TEST_P(HloEvaluatorTest,
{8, 9, 10},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1210,7 +1210,7 @@ TEST_P(HloEvaluatorTest,
{0, 0, 0},
{91, 98, 105},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1225,9 +1225,9 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
std::vector<float> v(kNumElements, 1.0f);
HloInstruction* arg_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
HloInstruction* init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1262,9 +1262,9 @@ void BM_ReducePrecisely(int num_iters) {
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
std::vector<float> v(kNumElements, 1.0f);
HloInstruction* arg_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1299,13 +1299,13 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1326,7 +1326,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<float>({6, 18});
+ auto expected = LiteralUtil::CreateR1<float>({6, 18});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1341,13 +1341,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder max_computation("max");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1378,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{6, 7}});
+ auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1392,13 +1392,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1435,7 +1435,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
+ auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1445,13 +1445,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
std::unique_ptr<Literal> arg_literal =
- Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1498,7 +1498,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
- Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
}
@@ -1513,7 +1513,8 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(3, 5);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
@@ -1527,7 +1528,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
@@ -1545,13 +1546,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0, 1})));
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
@@ -1560,7 +1562,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
@@ -1580,13 +1582,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2, 1})));
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
@@ -1595,7 +1598,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
@@ -1613,16 +1616,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
// }
auto operand_array = MakeUnique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
- auto operand_literal = Literal::CreateR2FromArray2D<double>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto update = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
+ LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
@@ -1631,7 +1635,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<double>({
+ auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
@@ -1649,12 +1653,13 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
// }
auto operand_array = MakeUnique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
- auto operand_literal2 = Literal::CreateR2FromArray2D<double>(*operand_array);
+ auto operand_literal2 =
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
HloInstruction* operand2 = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal2)));
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto tuple =
b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
@@ -1666,7 +1671,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<double>({
+ auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
@@ -1686,9 +1691,9 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
operand_array->FillUnique(1.0);
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D<double>(*operand_array)));
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array)));
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto tuple1 =
b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
@@ -1706,8 +1711,8 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
std::unique_ptr<Literal> result = Evaluate();
auto result_inner_literal =
- Literal::CreateR2FromArray2D<double>(*operand_array);
- auto expected = Literal::MakeTuple({
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
+ auto expected = LiteralUtil::MakeTuple({
result_inner_literal.get(),
result_inner_literal.get(),
});
@@ -1735,7 +1740,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
{{23.0f}, {24.0f}}},
});
// clang-format on
- auto operand_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto operand_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
@@ -1746,7 +1751,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
std::unique_ptr<Literal> result = Evaluate();
// clang-format off
- auto expected = Literal::CreateR4FromArray4D<float>({
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>({
{{{23.0f}, {24.0f}},
{{21.0f}, {22.0f}},
{{19.0f}, {20.0f}}},
@@ -1782,11 +1787,11 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
+ {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1799,18 +1804,18 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kMultiply, param0, param0));
- HloInstruction* constant = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
+ HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
HloInstruction* add = b.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square));
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1830,11 +1835,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1854,10 +1860,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1878,11 +1885,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<int32>(
+ *LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1904,13 +1911,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1932,13 +1939,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1959,10 +1966,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1983,11 +1991,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -2007,10 +2015,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -2031,11 +2040,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -2043,14 +2052,14 @@ ENTRY main {
// element-wise comparison with 2 bfloat16 operands.
TEST_P(HloEvaluatorTest, DoesCompareBF16) {
// lhs >= rhs
- auto lhs = Literal::CreateR2<bfloat16>(
+ auto lhs = LiteralUtil::CreateR2<bfloat16>(
{{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
{bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
- auto rhs = Literal::CreateR2<bfloat16>(
+ auto rhs = LiteralUtil::CreateR2<bfloat16>(
{{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
auto expected =
- Literal::CreateR2<bool>({{false, true, true}, {false, true, true}});
+ LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
std::move(rhs));
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index cdbac74ba4..2ae5f8bf36 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
@@ -1316,7 +1317,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(operand);
auto curr_val = arg_literal.Get<NativeT>(multi_index);
- auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val);
arg_literals.push_back(std::move(curr_val_literal));
}
@@ -1504,8 +1505,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
- auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
+ auto result_val_literal =
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
@@ -1583,10 +1585,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid
// dynamic memory allocations.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto selected_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto source_literal_scatter = Literal::CreateR0<ReturnT>(ReturnT());
- auto scattered_literal = Literal::CreateR0<ReturnT>(ReturnT());
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
do {
// For each element in `source`, we place a window in `operand`. For each
// window placement, we iterate inside the window twice:
@@ -1707,9 +1709,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Evaluate computation with specified literal operands.
const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
+ LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
@@ -1754,7 +1756,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = Literal::CreateFromDimensions(
+ auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 7a1372f929..57cf34d7de 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 68f41a1cbb..1d7a062c55 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -120,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) {
TEST(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
instruction->SetAndSanitizeName("i_am_a_constant_root_instruction");
HloModuleConfig config;
HloModule m(TestName(), config);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6ea302f8b4..19bee38790 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -163,6 +163,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dimensions().end()),
computations(0));
break;
+ case HloOpcode::kSort: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1 ||
+ proto.operand_ids_size() == 2)
+ << "Sort instruction should have 1 or 2 operands but has "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.dimensions().size() == 1)
+ << "Sort instruction should have 1 dimension";
+ HloInstruction* keys = operands(0);
+ HloInstruction* values =
+ proto.operand_ids_size() == 2 ? operands(1) : nullptr;
+ instruction =
+ CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ break;
+ }
case HloOpcode::kTranspose:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Transpose instruction should have 1 operand but sees "
@@ -271,7 +285,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// converted to take tokens.
instruction = CreateInfeed(data_shape, proto.infeed_config());
} else {
- CHECK_EQ(proto.operand_ids_size(), 2);
+ CHECK_EQ(proto.operand_ids_size(), 1);
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
}
@@ -372,6 +386,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
slice_sizes);
break;
}
+ case HloOpcode::kGather: {
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Gather instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_gather_dimension_numbers())
+ << "Gather instruction should have GatherDimensionNumbers set.";
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
+ MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
+ std::vector<int64> gather_window_bounds;
+ for (int64 bound : proto.gather_window_bounds()) {
+ gather_window_bounds.push_back(bound);
+ }
+ instruction =
+ CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_window_bounds);
+ break;
+ }
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -413,13 +444,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->set_sharding(sharding);
}
- if (proto.has_gather_dimension_numbers()) {
- instruction->gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- }
- for (int64 bound : proto.gather_window_bounds()) {
- instruction->gather_window_bounds_.push_back(bound);
- }
return std::move(instruction);
}
@@ -684,6 +708,7 @@ HloInstruction::CreateCrossReplicaSum(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ CHECK(!operands.empty());
auto instruction = WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
@@ -692,6 +717,11 @@ HloInstruction::CreateCrossReplicaSum(
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
+ return WrapUnique(
+ new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
@@ -909,13 +939,9 @@ HloInstruction::CreateBroadcastSequence(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
- const Shape& shape, HloInstruction* keys, HloInstruction* values) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSort, shape));
- instruction->AppendOperand(keys);
- if (values) {
- instruction->AppendOperand(values);
- }
- return instruction;
+ const Shape& shape, int64 dimension, HloInstruction* keys,
+ HloInstruction* values) {
+ return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
@@ -1020,34 +1046,8 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(gather_indices);
- instruction->gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
- return instruction;
-}
-
-/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim) {
- GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
- }
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
- }
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
- }
-
- gather_dim_numbers.set_index_vector_dim(index_vector_dim);
- return gather_dim_numbers;
+ return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
+ gather_dim_numbers, window_bounds);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
@@ -1110,6 +1110,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kGather:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1211,11 +1213,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kGather:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateGather(shape, new_operands[0], new_operands[1],
- *gather_dimension_numbers_, gather_window_bounds_);
- break;
case HloOpcode::kDomain:
CHECK_EQ(new_operands.size(), 1);
clone =
@@ -1223,15 +1220,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
user_side_metadata_->Clone());
break;
case HloOpcode::kAfterAll:
- clone = CreateAfterAll(new_operands);
- break;
- case HloOpcode::kSort:
- CHECK(new_operands.size() == 1 || new_operands.size() == 2)
- << "Too many operands for sort: " << new_operands.size();
- HloInstruction* keys = new_operands[0];
- HloInstruction* values =
- new_operands.size() == 2 ? new_operands[1] : nullptr;
- clone = CreateSort(shape, keys, values);
+ if (new_operands.empty()) {
+ clone = CreateToken();
+ } else {
+ clone = CreateAfterAll(new_operands);
+ }
break;
}
SetupDerivedInstruction(clone.get());
@@ -1509,7 +1502,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
- case HloOpcode::kSort:
case HloOpcode::kSin:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
@@ -1518,7 +1510,6 @@ bool HloInstruction::IdenticalSlowPath(
return true;
// These opcodes have complex or special behavior so just return false.
- case HloOpcode::kDomain:
case HloOpcode::kWhile:
case HloOpcode::kAfterAll:
return false;
@@ -1528,11 +1519,6 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
- case HloOpcode::kGather:
- return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
- other.gather_dimension_numbers()) &&
- gather_window_bounds() == other.gather_window_bounds();
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1540,6 +1526,10 @@ bool HloInstruction::IdenticalSlowPath(
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
+ case HloOpcode::kDomain:
+ return operand_side_metadata().Matches(other.operand_side_metadata()) &&
+ user_side_metadata().Matches(other.user_side_metadata());
+
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
@@ -1553,6 +1543,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
+ case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
@@ -1574,6 +1565,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kGather:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -1939,11 +1931,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
- if (gather_dimension_numbers_ != nullptr) {
- extra.push_back(GatherDimensionNumbersToString());
- extra.push_back(
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
- }
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
@@ -2073,14 +2060,6 @@ HloInstructionProto HloInstruction::ToProto() const {
if (dot_dimension_numbers_ != nullptr) {
*proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
}
- if (gather_dimension_numbers_ != nullptr) {
- *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
- }
- if (opcode() == HloOpcode::kGather) {
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
- }
- }
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
@@ -2841,26 +2820,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
-string HloInstruction::GatherDimensionNumbersToString() const {
- CHECK_NE(gather_dimension_numbers_.get(), nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
- string index_vector_dim = StrCat(
- "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
-
- return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
- ", ");
-}
-
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
@@ -3174,4 +3133,14 @@ int64 HloInstruction::slice_sizes(int64 dimension) const {
const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}
+
+const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
+ return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
+}
+
+tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
+ const {
+ return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+}
+
} // namespace xla
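What the delegating accessors mean for callers, as a sketch (assuming an arbitrary HloInstruction* named instr): the getters still live on the base class but now hard-CHECK the dynamic type via Cast<>:

  if (instr->opcode() == HloOpcode::kGather) {
    // Cast<HloGatherInstruction> succeeds only for kGather instructions.
    const GatherDimensionNumbers& dnums = instr->gather_dimension_numbers();
    tensorflow::gtl::ArraySlice<int64> bounds = instr->gather_window_bounds();
  }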
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 34e7dcb43d..cbd78fa124 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -33,7 +33,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/iterator_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -615,7 +615,7 @@ class HloInstruction {
// Creates a sort op, with a keys operand, and an optional values operand.
static std::unique_ptr<HloInstruction> CreateSort(
- const Shape& shape, HloInstruction* keys,
+ const Shape& shape, int64 dimension, HloInstruction* keys,
HloInstruction* values = nullptr);
// Creates a while instruction, given a condition computation, a body
@@ -687,17 +687,18 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
- // Creates a token instruction used for joining or creating new values of
- // token type which thread through side-effecting operations.
+ // Creates an AfterAll instruction used for joining or creating new values of
+ // token type which thread through side-effecting operations. Operands must
+ // all be tokens, and there must be at least one operand.
static std::unique_ptr<HloInstruction> CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Creates an instance of GatherDimensionNumbers.
- static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim);
+ // Creates an AfterAll instruction which creates a token type out of thin air
+ // (no operands). This is a separate method from CreateAfterAll to facilitate
+ // the removal of operand-less AfterAll instructions.
+ // TODO(b/110532604): Remove this capability of creating a token from nothing
+ // when we plumb a primordial token from the entry computation.
+ static std::unique_ptr<HloInstruction> CreateToken();
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
@@ -1073,19 +1074,6 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
- const GatherDimensionNumbers& gather_dimension_numbers() const {
- CHECK(gather_dimension_numbers_ != nullptr);
- return *gather_dimension_numbers_;
- }
-
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- CHECK_EQ(opcode(), HloOpcode::kGather);
- return gather_window_bounds_;
- }
-
- // Returns the dump string of the gather dimension numbers.
- string GatherDimensionNumbersToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1452,6 +1440,12 @@ class HloInstruction {
// Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
const std::vector<int64>& dynamic_slice_sizes() const;
+
+ // Delegates to HloGatherInstruction::gather_dimension_numbers.
+ const GatherDimensionNumbers& gather_dimension_numbers() const;
+ // Delegates to HloGatherInstruction::gather_window_bounds.
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1595,9 +1589,6 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
- std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
-
// Used to tag kCopy instructions that are eligible for copy elision.
bool copy_elision_allowed_ = true;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index d8ca99dfd1..b75a2bd34b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -20,10 +20,11 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -249,7 +250,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0f32_, "param1"));
auto c0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto addleft = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0));
auto addright = builder.AddInstruction(
@@ -294,7 +295,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0f32_, "param1"));
auto c0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto neg1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0));
auto addleft = builder.AddInstruction(
@@ -334,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
auto param = embedded_builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "x"));
auto value = embedded_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
@@ -383,9 +384,9 @@ TEST_F(HloInstructionTest, TrivialReduce) {
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32a100x10, "p"));
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto reduce = builder.AddInstruction(
HloInstruction::CreateReduce(f32v100, param0, const0,
/*dimensions_to_reduce=*/{1}, add_f32));
@@ -626,7 +627,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto module = CreateNewModule();
@@ -642,9 +643,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single binary operation.
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto module = CreateNewModule();
@@ -661,7 +662,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
@@ -682,7 +683,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
@@ -710,13 +711,13 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto outfeed10 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape10, constant, token, ""));
auto outfeed01 = builder.AddInstruction(
@@ -732,7 +733,7 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
@@ -763,7 +764,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto map_1_x = builder.AddInstruction(
HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
auto map_2_x = builder.AddInstruction(
@@ -798,11 +799,11 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
// Notable complexities are repeated operands in the same instruction,
// different shapes, use of value in different expressions.
auto c1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
auto c3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
@@ -873,11 +874,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) {
// Create a set of random constant operands to use below. Make them matrices
// so dimensions are interesting.
auto operand1 = HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
auto operand2 = HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
- auto vector_operand =
- HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0, 123.0}));
+ LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
+ auto vector_operand = HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0, 123.0}));
Shape shape = operand1->shape();
// Convenient short names for the operands.
@@ -1234,9 +1235,9 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
// Build a nested fusion computation.
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto b_t = builder.AddInstruction(
HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
DotDimensionNumbers dot_dnums;
@@ -1245,7 +1246,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1342,7 +1343,7 @@ TEST_F(HloInstructionTest, Stringification) {
"condition=%TransposeDot, body=%TransposeDot");
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
sout, pred, x, computation, x, computation));
@@ -1369,7 +1370,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
HloInstruction* gather_instruction =
builder.AddInstruction(HloInstruction::CreateGather(
gather_result_shape, input, gather_indices,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1405,7 +1406,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
HloInstruction* gather_instruction =
builder.AddInstruction(HloInstruction::CreateGather(
gather_result_shape, input, gather_indices,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1455,15 +1456,15 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kLoop);
- EXPECT_EQ(
- fusion->ToString(options),
+ const string expected_fusion =
R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
@@ -1495,8 +1496,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
HloInstruction::CreateWhile(sout, computation, computation, x));
auto options = HloPrintOptions().Canonical();
- EXPECT_EQ(loop->ToString(options),
- R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+ const string expected_loop =
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
@@ -1518,7 +1519,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
-})");
+})";
+ EXPECT_EQ(loop->ToString(options), expected_loop);
}
TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
@@ -1550,13 +1552,12 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
HloInstruction::CreateWhile(sout, computation, computation, x));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
sout, pred, x, computation, x, computation));
auto options = HloPrintOptions().Canonical();
- EXPECT_EQ(
- conditional->ToString(options),
+ const string expected_conditional =
R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
@@ -1579,7 +1580,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
-})");
+})";
+ EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
TEST_F(HloInstructionTest, CheckDeepClone) {
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 7052e236cd..f333c489ed 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <deque>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -468,6 +469,46 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], new_operands[1], dimensions(), to_apply());
}
+HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values)
+ : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+ AppendOperand(keys);
+ if (values) {
+ AppendOperand(values);
+ }
+}
+
+HloInstructionProto HloSortInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloSortInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloSortInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ HloInstruction* keys = new_operands[0];
+ HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
+ return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+}
+
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
@@ -766,7 +807,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstruction* operand)
: HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
- literal_(Literal::CreateR1U8(tag)) {
+ literal_(LiteralUtil::CreateR1U8(tag)) {
AppendOperand(operand);
operand->set_tracing(this);
}
@@ -1052,8 +1093,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
} else {
- clone = fused_instructions_computation()->AddInstruction(
- instruction_to_fuse->Clone(/*suffix=*/""));
// When add_output is false, instruction_to_fuse is necessarily an operand
// of the fusion instruction. After fusion this will no longer be the
// case. Remove the operand from the operand list and remove its
@@ -1063,6 +1102,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
bool in_operand_list = std::find(operands().begin(), operands().end(),
instruction_to_fuse) != operands().end();
CHECK(add_output || in_operand_list);
+ if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
+ // We assume all uses of a kTuple operation are GTE ops, not another
+ // fusion node. In this case, we don't need to clone
+ // 'instruction_to_fuse'.
+ CHECK(!in_operand_list);
+ clone = instruction_to_fuse;
+ } else {
+ clone = fused_instructions_computation()->AddInstruction(
+ instruction_to_fuse->Clone(/*suffix=*/""));
+ }
const std::vector<HloInstruction*>& fused_parameters =
fused_instructions_computation()->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -1159,9 +1208,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
int64 index = tuple_elements.size();
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
- index -= instruction_to_fuse->operand_count();
+ CHECK_EQ(clone, instruction_to_fuse);
+ index -= clone->operand_count();
std::vector<HloInstruction*> to_be_removed;
- for (auto old_gte : instruction_to_fuse->users()) {
+ for (auto old_gte : clone->users()) {
CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
int64 old_tuple_index = old_gte->tuple_index();
HloInstruction* new_gte =
@@ -1173,7 +1223,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
for (auto old_gte : to_be_removed) {
TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
}
- TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
} else {
HloInstruction* new_gte =
parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -1182,7 +1231,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
}
- VLOG(2) << "New clone:\n" << clone->ToString();
+ if (clone != instruction_to_fuse) {
+ VLOG(2) << "New clone:\n" << clone->ToString();
+ }
return clone;
}
@@ -1863,4 +1914,93 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
return MakeUnique<HloDynamicSliceInstruction>(
shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
}
+
+HloGatherInstruction::HloGatherInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds)
+ : HloInstruction(HloOpcode::kGather, shape) {
+ AppendOperand(operand);
+ AppendOperand(gather_indices);
+ gather_dimension_numbers_ =
+ MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
+ c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+}
+
+string HloGatherInstruction::GatherDimensionNumbersToString() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ string output_window_dims =
+ StrCat("output_window_dims={",
+ Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
+ string elided_window_dims =
+ StrCat("elided_window_dims={",
+ Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
+ string gather_dims_to_operand_dims = StrCat(
+ "gather_dims_to_operand_dims={",
+ Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string index_vector_dim = StrCat(
+ "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
+
+ return Join<std::initializer_list<string>>(
+ {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
+ index_vector_dim},
+ ", ");
+}
+
+/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ int64 index_vector_dim) {
+ GatherDimensionNumbers gather_dim_numbers;
+ for (int64 output_window_dim : output_window_dims) {
+ gather_dim_numbers.add_output_window_dims(output_window_dim);
+ }
+ for (int64 elided_window_dim : elided_window_dims) {
+ gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ }
+ for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
+ gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ }
+
+ gather_dim_numbers.set_index_vector_dim(index_vector_dim);
+ return gather_dim_numbers;
+}
+
+HloInstructionProto HloGatherInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
+ for (int64 bound : gather_window_bounds()) {
+ proto.add_gather_window_bounds(bound);
+ }
+ return proto;
+}
+
+std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {GatherDimensionNumbersToString(),
+ StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+}
+
+bool HloGatherInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
+ return protobuf_util::ProtobufEquals(
+ gather_dimension_numbers(),
+ casted_other.gather_dimension_numbers()) &&
+ gather_window_bounds() == casted_other.gather_window_bounds();
+}
+
+std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloGatherInstruction>(
+ shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
+ gather_window_bounds());
+}
+
} // namespace xla
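A sketch of building dimension numbers through the relocated static helper; the dimension values mirror the StringifyGather tests above:

  GatherDimensionNumbers dnums = HloGatherInstruction::MakeGatherDimNumbers(
      /*output_window_dims=*/{4, 5, 6, 7, 8},
      /*elided_window_dims=*/{},
      /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);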
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index df6969c410..65a93cdcf1 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -349,6 +349,35 @@ class HloReduceInstruction : public HloInstruction {
std::vector<int64> dimensions_;
};
+class HloSortInstruction : public HloInstruction {
+ public:
+ explicit HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values = nullptr);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns the sort dimension for this instruction.
+ int64 sort_dimension() { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
class HloTransposeInstruction : public HloInstruction {
public:
explicit HloTransposeInstruction(
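A short sketch of reading the sort dimension through the new subclass, assuming instr is a kSort HloInstruction*:

  int64 dim = Cast<HloSortInstruction>(instr)->sort_dimension();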
@@ -1119,6 +1148,49 @@ class HloDynamicSliceInstruction : public HloInstruction {
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
};
+
+class HloGatherInstruction : public HloInstruction {
+ public:
+ explicit HloGatherInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ const GatherDimensionNumbers& gather_dimension_numbers() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ return *gather_dimension_numbers_;
+ }
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
+ return gather_window_bounds_;
+ }
+ // Returns the dump string of the gather dimension numbers.
+ string GatherDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of GatherDimensionNumbers.
+ static GatherDimensionNumbers MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
+ std::vector<int64> gather_window_bounds_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 0275294a1a..01b625c29c 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 9a3010cf1f..7de59acc1e 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -75,8 +76,10 @@ TEST(HloMatchersTest, Test) {
}
TEST(HloMatchersTest, CustomCallMatcher) {
- auto c1 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}));
+ auto c1 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
+ auto c2 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3}));
auto call = HloInstruction::CreateCustomCall(
ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target");
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 39bc25ba42..55ff073d3f 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -537,10 +537,11 @@ uint64 HloModule::RandomNew64() const {
HloComputation* HloModule::GetComputationWithName(
tensorflow::StringPiece name) {
- auto it = c_find_if(computations(), [&](HloComputation* computation) {
+ auto computations_in_module = computations();
+ auto it = c_find_if(computations_in_module, [&](HloComputation* computation) {
return computation->name() == name;
});
- return it == computations().end() ? nullptr : *it;
+ return it == computations_in_module.end() ? nullptr : *it;
}
/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
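The named local above is a lifetime fix, sketched here: computations() returns a fresh range object on every call, so the old code kept an iterator into one temporary and compared it against the end() of a second one.

  // Before (buggy shape): `it` points into a temporary range that no longer
  // exists by the time it is compared with a second temporary's end().
  //   auto it = c_find_if(computations(), pred);
  //   return it == computations().end() ? nullptr : *it;
  // After: a single named range outlives both uses.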
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 21a9b7291a..df1d562048 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -292,7 +292,7 @@ HloModuleGroupUtil::ComputeReachability(
}
auto reachability = MakeUnique<HloReachabilityMap>(post_order);
for (HloInstruction* hlo : post_order) {
- reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
+ reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
}
return std::move(reachability);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 7f28a804bf..236f450086 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase {
std::unique_ptr<HloComputation> CreateConstantComputation() {
auto builder = HloComputation::Builder("Constant");
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
return builder.Build();
}
@@ -122,7 +122,7 @@ TEST_F(HloModuleTest, CloneHasFusion) {
{
auto b = HloComputation::Builder("Entry");
auto input = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
b.AddInstruction(
HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
/*operands=*/{input}, fused_computation));
@@ -173,7 +173,7 @@ TEST_F(HloModuleTest, LargeConstantToString) {
auto builder = HloComputation::Builder("Constant");
std::vector<float> values(16, 42.0);
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(values)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
module->AddEntryComputation(builder.Build());
EXPECT_EQ(
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index cfe5dace05..126d3a2d9c 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -57,7 +57,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
auto builder_c = HloComputation::Builder("C");
HloInstruction* c = builder_c.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
HloComputation* computation_c =
module->AddEmbeddedComputation(builder_c.Build());
@@ -145,7 +145,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
module->AddEntryComputation(builder.Build());
@@ -208,7 +208,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index f192debc9c..d387539350 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -15,8 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -621,23 +623,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction =
- builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
+ if (operands.empty()) {
+ instruction = builder->AddInstruction(HloInstruction::CreateToken());
+ } else {
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
+ }
break;
}
case HloOpcode::kSort: {
auto loc = lexer_.GetLoc();
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
+ dimensions->size() != 1) {
return false;
}
switch (operands.size()) {
case 1:
- instruction = builder->AddInstruction(
- HloInstruction::CreateSort(shape, /*keys=*/operands[0]));
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0), /*keys=*/operands[0]));
break;
case 2:
instruction = builder->AddInstruction(HloInstruction::CreateSort(
- shape,
+ shape, dimensions->at(0),
/*keys=*/operands[0], /*values=*/operands[1]));
break;
default:
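For reference, sort HLO text that parses after this change; the dimensions={0} attribute is now required (taken from the parser tests below):

  ENTRY Sort {
    keys = f32[1024]{0} parameter(0)
    values = s32[1024]{0} parameter(1)
    ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
  }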
@@ -1182,11 +1193,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return false;
}
- GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
- /*index_vector_dim=*/*index_vector_dim);
+ GatherDimensionNumbers dim_numbers =
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/*output_window_dims,
+ /*elided_window_dims=*/*elided_window_dims,
+ /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
@@ -1609,7 +1621,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
}
}
}
- *literal = Literal::MakeTupleOwned(std::move(elements));
+ *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
return ParseToken(TokKind::kRparen,
StrCat("expects ')' at the end of the tuple with ",
ShapeUtil::TupleElementCount(shape), "elements"));
@@ -1637,8 +1649,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
}
// Create a literal with the given shape in default layout.
- *literal = Literal::CreateFromDimensions(shape.element_type(),
- AsInt64Slice(shape.dimensions()));
+ *literal = LiteralUtil::CreateFromDimensions(
+ shape.element_type(), AsInt64Slice(shape.dimensions()));
tensorflow::int64 nest_level = 0;
tensorflow::int64 linear_index = 0;
// elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 88f3309baa..f06c705c42 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -840,7 +840,7 @@ R"(HloModule sort
ENTRY Sort {
x = f32[1024]{0} parameter(0)
- ROOT sorted = f32[1024]{0} sort(x)
+ ROOT sorted = f32[1024]{0} sort(x), dimensions={0}
}
)"
@@ -853,7 +853,32 @@ R"(HloModule sort
ENTRY Sort {
keys = f32[1024]{0} parameter(0)
values = s32[1024]{0} parameter(1)
- ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
+}
+
+)"
+},
+// R2 Sort (Key)
+{
+"SortKeyR2",
+R"(HloModule sort
+
+ENTRY Sort {
+ x = f32[1024,16]{0,1} parameter(0)
+ ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}
+}
+
+)"
+},
+// R2 Sort (Key, Value)
+{
+"SortKeyValueR2",
+R"(HloModule sort
+
+ENTRY Sort {
+ keys = f32[1024,16]{0,1} parameter(0)
+ values = s32[1024,16]{0,1} parameter(1)
+ ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index 2418c19f3d..2a07b6fcbc 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_query.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 657a9ee83d..585c95972b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -39,15 +39,15 @@ TEST_F(HloReachabilityTest, Reachability) {
*/
auto builder = HloComputation::Builder(TestName());
auto a = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto b = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto c = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto d = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto e = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.Build();
HloReachabilityMap reachability({a, b, c, d, e});
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0b222f4348..59a8800a7d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1202,17 +1202,14 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes) {
+ int64 memory_limit_bytes, RematerializationSizes* sizes,
+ bool run_copy_elision) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
VLOG(1) << "HloRematerialization() with memory limit of "
<< HumanReadableNumBytes(memory_limit_bytes);
- if (copy_insertion_) {
- TF_RETURN_IF_ERROR(copy_insertion_->Run(module).status());
- }
-
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1241,12 +1238,13 @@ StatusOr<bool> HloRematerialization::Run(
return size_function_(buffer.shape());
},
scheduler_algorithm_));
- if (copy_insertion_) {
+ if (run_copy_elision) {
// We run a separate pass of copy elision here because the sequential
// ordering from the HLO schedule allows for more copies to be eliminated.
+ // TODO(b/80249101): Instead of a separate copy elision pass, use the
+ // ordering from the HLO schedule directly for copy insertion.
SequentialHloOrdering ordering(module, *sequence);
- TF_RETURN_IF_ERROR(
- copy_insertion_->RemoveUnnecessaryCopies(ordering, module));
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
}
// Compute peak memory usage of all computations in the module called in a
@@ -1351,10 +1349,10 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(std::move(scheduler_algorithm), size_function,
- copy_insertion);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
+ RematerializationSizes* sizes, bool run_copy_elision) {
+ HloRematerialization remat(scheduler_algorithm, size_function);
+ return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ run_copy_elision);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 1c72f42b8c..59b4cf5dcc 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,7 +17,6 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -58,8 +57,11 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
- // copy_insertion: If non-null, run the provided copy insertion pass
- // before HLO scheduling.
+ // run_copy_elision: If true, run a copy-elision pass to remove copies
+ // that were inserted before HLO scheduling.
+ //
+ // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
+ // insertion is integrated with HLO scheduling.
//
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
@@ -72,15 +74,13 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
+ RematerializationSizes* sizes, bool run_copy_elision = true);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function,
- CopyInsertion* copy_insertion)
+ const ShapeSizeFunction& size_function)
: scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function),
- copy_insertion_(copy_insertion) {}
+ size_function_(size_function) {}
~HloRematerialization() {}
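A sketch of a call site under the new signature, modeled on the test helper in hlo_rematerialization_test.cc; ByteSizeOf, module, and DefaultMemoryScheduler come from that fixture:

  SequentialHloOrdering::HloModuleSequence sequence;
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      HloRematerialization::RematerializeAndSchedule(
          ByteSizeOf, /*memory_limit_bytes=*/14 * 1024, module.get(),
          DefaultMemoryScheduler, &sequence, /*sizes=*/nullptr,
          /*run_copy_elision=*/false));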
// Runs rematerialization on the given module. Returns whether the module was
@@ -89,7 +89,8 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit, RematerializationSizes* sizes);
+ int64 memory_limit, RematerializationSizes* sizes,
+ bool run_copy_elision);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
@@ -145,9 +146,6 @@ class HloRematerialization {
// uses of the original instruction and the original instruction is
// dead. Hence, no net instructions were added.
int64 net_instructions_added_ = 0;
-
- // Copy insertion pass that runs before HLO scheduling.
- CopyInsertion* copy_insertion_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index fc137c839f..cd131147e6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -132,7 +132,7 @@ class HloRematerializationTest : public HloTestBase {
builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
return builder.Build();
}
@@ -143,12 +143,11 @@ class HloRematerializationTest : public HloTestBase {
StatusOr<bool> RunHloRematerialization(
int64 memory_limit_bytes, HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence,
- CopyInsertion* copy_insertion = nullptr) {
+ SequentialHloOrdering::HloModuleSequence* sequence) {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence, /*sizes=*/nullptr, copy_insertion);
+ sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
}
// Various shapes used in the canned computations.
@@ -227,7 +226,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -264,7 +263,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -288,41 +287,6 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(body_computation->instruction_count(), 9);
}
-// Similar to RematerializeEntryAndWhileBody, except with copy insertion run
-// after HLO scheduling.
-TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBodyWithCopies) {
- auto module = CreateNewModule();
-
- auto cond_builder = HloComputation::Builder(TestName() + ".cond");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, vec1_shape_, "param"));
- cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
- HloComputation* while_cond =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- HloComputation* body_computation = module->AddEmbeddedComputation(
- MakeRematerializableComputation(/*suffix=*/".body"));
- HloComputation* entry_computation =
- module->AddEntryComputation(MakeRematerializableWhileComputation(
- while_cond, /*while_body=*/body_computation));
-
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(body_computation->instruction_count(), 8);
-
- SequentialHloOrdering::HloModuleSequence sequence;
- CopyInsertion copy_insertion;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024, module.get(),
- &sequence, &copy_insertion));
- EXPECT_TRUE(changed);
-
- // Both computations should have rematerialized instructions added.
- EXPECT_EQ(entry_computation->instruction_count(), 9);
- EXPECT_EQ(body_computation->instruction_count(), 9);
-}
-
// Test rematerialization of a doubly nested computation. All computations
// should have an instruction rematerialized.
TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
@@ -332,7 +296,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index b2725e2918..4f0569f405 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -180,12 +180,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
- TF_ASSIGN_OR_RETURN(
- ScopedShapedBuffer retval,
- executable->ExecuteOnStreamWrapper(&service_run_options,
- /*profile=*/profile, arguments));
- TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
- return std::move(retval);
+ return executable->ExecuteOnStreamWrapper(&service_run_options,
+ /*profile=*/profile, arguments);
}
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
@@ -313,7 +309,6 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
std::vector<std::unique_ptr<Literal>> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
- TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 73f22f81f4..cf9ceed5b2 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -168,8 +168,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "cond_param"));
- HloInstruction* zero_vector = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ HloInstruction* zero_vector =
+ cond_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
@@ -179,16 +180,18 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "body_param"));
- HloInstruction* one_vector = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* one_vector =
+ body_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
body_builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, body_param, one_vector));
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
// transpose(matrix) + bcast(while)
auto builder = HloComputation::Builder(TestName());
- HloInstruction* while_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* while_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
// Creates 16 bytes, ignoring subcomputations
HloInstruction* while_loop =
builder.AddInstruction(HloInstruction::CreateWhile(
@@ -199,7 +202,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
HloInstruction* matrix = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
{{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
// Creates 32 bytes
HloInstruction* transpose = builder.AddInstruction(
@@ -257,7 +260,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
// Wrap lit in abs because constants are considered free by
// IgnoreInstruction, and it skews the accounting.
auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 1, 1, 1, 1, 1})));
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
auto abs_const = builder.AddInstruction(
HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
@@ -300,11 +303,11 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 1, 1, 1, 1})));
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 2, 3, 4, 5})));
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0, 2, 4, 6, 8})));
+ LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
@@ -354,8 +357,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "cond_param"));
- HloInstruction* zero_vector = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ HloInstruction* zero_vector =
+ cond_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
@@ -365,15 +369,17 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "body_param"));
- HloInstruction* one_vector = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* one_vector =
+ body_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
body_builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, body_param, one_vector));
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
- HloInstruction* while_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* while_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
// Creates 16 bytes, ignoring subcomputations
builder.AddInstruction(HloInstruction::CreateWhile(
r1f32, cond_computation, body_computation, while_init));
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 268b4727bc..393944c20f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple(
const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
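+ // Each entry must be a leaf sharding; callers flatten nested tuple
+ // shardings into one entry per leaf before calling this constructor.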
+ for (auto& sharding : shardings) {
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ }
std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
<< "Flat list has " << flattened_list.size() << ", required "
@@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple(
return HloSharding(flattened_list);
}
+HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding) {
+ CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(leaf_count);
+ for (int64 i = 0; i < leaf_count; ++i) {
+ flattened_list.push_back(sharding);
+ }
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::Single(const Shape& shape,
+ const HloSharding& sharding) {
+ return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding;
+}
+
string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 34324d2058..6f672b0f28 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -80,6 +80,15 @@ class HloSharding {
static HloSharding Tuple(const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings);
+ // Creates a new sharding for a tuple type, with a single input sharding
+ // repeated on each leaf.
+ static HloSharding SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding);
+
+ // If the shape is an array, returns the input sharding unchanged; otherwise
+ // returns a tuple-shaped sharding with every leaf set to the input sharding.
+ static HloSharding Single(const Shape& shape, const HloSharding& sharding);
+
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
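A short usage sketch for the two new helpers; HloSharding::AssignDevice is assumed from this class's existing factory methods:

    Shape tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(S32, {})});
    HloSharding device0 = HloSharding::AssignDevice(0);
    // Every leaf of the tuple receives the same device assignment.
    HloSharding tuple_sharding = HloSharding::SingleTuple(tuple_shape, device0);
    // For an array shape, Single() returns the input sharding unchanged.
    HloSharding array_sharding =
        HloSharding::Single(ShapeUtil::MakeShape(F32, {4}), device0);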
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 39036e205e..4f91d619ef 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -88,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
VLOG(2) << " " << instruction->ToString();
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ pass_through.emplace_back(nullptr, instruction);
+ VLOG(2) << "Found passthrough domain link:";
+ VLOG(2) << " <root>";
+ VLOG(2) << " " << instruction->ToString();
+ }
}
return pass_through;
}
@@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
tuple, 0));
gte->set_sharding(sharding);
- TF_RETURN_IF_ERROR(
- pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ if (pass_through.user != nullptr) {
+ TF_RETURN_IF_ERROR(
+ pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ } else {
+ pass_through.operand->parent()->set_root_instruction(gte);
+ }
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 54b7402b86..7baa927d0e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
index 7b601f9a95..45c684d667 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
auto x = builder.AddInstruction(
HloInstruction::CreateCall(r0s32_, {constant}, callee1));
auto y = builder.AddInstruction(
@@ -112,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(3)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
auto x = builder.AddInstruction(
HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1));
auto y = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 3dc733940f..48f676db85 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index be156d765d..1e2b31a1f2 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -90,7 +90,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
auto builder = HloComputation::Builder("Const");
HloInstruction *instruction = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(123)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
OpMetadata metadata;
metadata.set_op_name("x");
metadata.set_op_type("y");
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f896773729..48eeba6afd 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -127,6 +127,22 @@ Status CheckIsTokenOperand(const HloInstruction* instruction,
return Status::OK();
}
+Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number) {
+ const HloInstruction* operand = instruction->operand(operand_number);
+ const HloInstruction* parameter =
+ computation->parameter_instruction(parameter_number);
+ if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ return InternalError("Operand %s shape does not match parameter's %s in %s",
+ operand->ToString().c_str(),
+ parameter->ToString().c_str(),
+ instruction->ToString().c_str());
+ }
+ return Status::OK();
+}
+
} // namespace
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
@@ -253,8 +269,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleCall(HloInstruction* call) {
+ for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
+ TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
+ }
// The shape of kCall should match the shape of the computation it calls.
- return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
+ return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
@@ -323,19 +342,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
+ TF_RETURN_IF_ERROR(
+ CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
+ TF_RETURN_IF_ERROR(
+ CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
+ const Shape& conditional_shape =
+ xla_while->while_condition()->root_instruction()->shape();
+ if (!ShapeUtil::Compatible(conditional_shape,
+ ShapeUtil::MakeShape(PRED, {}))) {
+ return InternalError(
+ "Conditional computation shape does not lead to a scalar predicate "
+ "shape: %s",
+ ShapeUtil::HumanString(conditional_shape).c_str());
+ }
// The shape of kWhile should match the shape of the body computation it
// calls.
return CheckShape(xla_while,
- xla_while->while_body()->ComputeProgramShape().result());
+ xla_while->while_body()->root_instruction()->shape());
}
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
+ TF_RETURN_IF_ERROR(CheckOperandAndParameter(
+ conditional, 1, conditional->true_computation(), 0));
+ TF_RETURN_IF_ERROR(CheckOperandAndParameter(
+ conditional, 2, conditional->false_computation(), 0));
+ TF_RETURN_IF_ERROR(
+ CheckShape(conditional,
+ conditional->true_computation()->root_instruction()->shape()));
TF_RETURN_IF_ERROR(CheckShape(
conditional,
- conditional->true_computation()->ComputeProgramShape().result()));
- return CheckShape(
- conditional,
- conditional->false_computation()->ComputeProgramShape().result());
+ conditional->false_computation()->root_instruction()->shape()));
+ return Status::OK();
}
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
@@ -802,33 +839,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
"While loop must have exactly one operand; had %lld : %s",
instruction->operand_count(), instruction->ToString().c_str());
}
- auto* init = instruction->operand(0);
- auto* cond_param = while_cond->parameter_instruction(0);
- if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) {
- return FailedPrecondition(
- "While condition's parameter must have the same shape as the "
- "loop's 'init'. init: %s, param: %s",
- init->ToString().c_str(), cond_param->ToString().c_str());
- }
- auto* cond_root = while_cond->root_instruction();
- if (!ShapeUtil::Compatible(cond_root->shape(),
- ShapeUtil::MakeShape(PRED, {}))) {
- return FailedPrecondition("While condition should have shape PRED: %s",
- cond_root->ToString().c_str());
- }
- auto* body_param = while_body->parameter_instruction(0);
- if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) {
+ return Status::OK();
+}
+
+Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
+ if (instruction->true_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "While body's parameter must have the same shape as the loop's"
- " 'init'. init: %s, param: %s",
- init->ToString().c_str(), body_param->ToString().c_str());
+ "True computation %s of %s must have 1 parameter insted of %lld",
+ instruction->true_computation()->name().c_str(),
+ instruction->ToString().c_str(),
+ instruction->true_computation()->num_parameters());
}
- auto* body_root = while_body->root_instruction();
- if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) {
+ if (instruction->false_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "While body should have same shape as the loop's 'init'."
- "init: %s, body: %s",
- init->ToString().c_str(), body_root->ToString().c_str());
+ "False computation %s of %s must have 1 parameter insted of %lld",
+ instruction->false_computation()->name().c_str(),
+ instruction->ToString().c_str(),
+ instruction->false_computation()->num_parameters());
}
return Status::OK();
}
@@ -924,6 +951,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
} else if (instruction->opcode() == HloOpcode::kWhile) {
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
+ } else if (instruction->opcode() == HloOpcode::kConditional) {
+ TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
} else if (instruction->opcode() !=
HloOpcode::kRng /* Rng operands are always scalar. */
&& instruction->IsElementwise()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 12c047850e..9e62bdc8a9 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -146,6 +146,8 @@ class HloVerifier : public HloPassInterface {
Status CheckWhileInstruction(HloInstruction* instruction);
+ Status CheckConditionalInstruction(HloInstruction* instruction);
+
// Checks that the non-scalar operand shapes are compatible to the output
// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
Status CheckElementwiseInstruction(HloInstruction* instruction);
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index c92db0be14..04c6ba3eeb 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
EXPECT_FALSE(verifier().Run(module.get()).status().ok());
}
+TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+callme {
+ ROOT param = (s32[], f32[4]) parameter(0)
+}
+
+ENTRY entry {
+ p0 = (f32[4], s32[]) parameter(0)
+ ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("shape does not match parameter"));
+}
+
+TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+true_branch {
+ tparam = (s32[], f32[4]) parameter(0)
+ ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
+}
+
+false_branch {
+ fparam = (s32[], f32[4]) parameter(0)
+ ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
+}
+
+ENTRY entry {
+ p0 = (f32[4], s32[]) parameter(0)
+ constant = pred[] constant(true)
+ ROOT conditional = f32[4] conditional(constant, p0, p0),
+ true_computation=true_branch, false_computation=false_branch
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("shape does not match parameter"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index 8c7b38dd1b..f85d31d522 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 1985d20578..8b2df32567 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -160,6 +161,12 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
computed_array,
ComputeArrayForReshape(instr->shape(),
FindOrDie(cache_, instr->operand(0))));
+ } else if (instr->opcode() == HloOpcode::kDot) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
} else {
computed_array = nullptr;
}
@@ -290,8 +297,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
}
if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
- auto it = c_find(indexed->output_dims(), source_dim);
- if (it != indexed->output_dims().end()) {
+ if (c_linear_search(indexed->output_dims(), source_dim)) {
return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
shape);
}
@@ -956,11 +962,177 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
return Construct<ScalarIndexedConstantArray>(
new_source, scalar_indexed_const->indices(),
scalar_indexed_const->source_dim(),
- std::vector<int64>(scalar_indexed_const->output_dims().begin(),
- scalar_indexed_const->output_dims().end()),
+ ArraySliceToVector(scalar_indexed_const->output_dims()),
scalar_indexed_const->shape());
}
+namespace {
+
+// Returns the non-contracting non-batch dimension (as per `contracting_dims`
+// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
+gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
+ int64 rank, ArraySlice<int64> contracting_dims,
+ ArraySlice<int64> batch_dims) {
+ gtl::optional<int64> result;
+ for (int64 dim = 0; dim < rank; dim++) {
+ if (!ArrayContains(contracting_dims, dim) &&
+ !ArrayContains(batch_dims, dim)) {
+ if (result.has_value()) {
+ return gtl::nullopt;
+ }
+ result = dim;
+ }
+ }
+ return result;
+}
+
+// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
+// HLO, can be folded into the dot operation. For now these conditions are both
+// necessary and sufficient.
+//
+// `tag` describes the caller. Used only for logging.
+//
+// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
+// of whatever operand `indexed_array` is to the dot (LHS or RHS).
+bool CanFoldDotIntoIndexedArray(
+ tensorflow::StringPiece tag,
+ Analysis::ScalarIndexedConstantArray* indexed_array,
+ ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
+ gtl::optional<int64> non_contracting_non_batch_dim =
+ GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
+ contracting_dims, batch_dims);
+ if (!non_contracting_non_batch_dim.has_value()) {
+ VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
+ return false;
+ }
+
+ if (indexed_array->output_dims().size() != 1 ||
+ indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
+ VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim";
+ return false;
+ }
+
+ int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape());
+ if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
+ // This restriction can be lifted by inserting reshape nodes.
+ VLOG(3) << tag
+ << ": source dim is not in the low two dims, won't be able to form "
+ "a matmul";
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
+ VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
+ << ToString(rhs);
+ if (!CanFoldDotIntoIndexedArray(
+ "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
+ AsInt64Slice(dim_numbers.lhs_contracting_dimensions()),
+ /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) {
+ return nullptr;
+ }
+
+ int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
+ DotDimensionNumbers new_dim_numbers = dim_numbers;
+ new_dim_numbers.set_lhs_contracting_dimensions(
+ 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
+
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, lhs->literal(), *rhs->literal())));
+
+ // The new source dimension is wherever the non-batch non-contracting LHS
+ // dimension "went".
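+ // (For a plain rank-2 matmul with no batch dimensions this is 0, matching
+ // the "%indices 0->[0]" annotation in the DotOpBasic_0 test below.)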
+ int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
+ dim_numbers.rhs_batch_dimensions_size();
+
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, lhs->indices(), new_source_dim,
+ ArraySliceToVector(lhs->output_dims()), shape);
+}
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
+ << ToString(rhs);
+ if (!CanFoldDotIntoIndexedArray(
+ "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
+ AsInt64Slice(dim_numbers.rhs_contracting_dimensions()),
+ /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) {
+ return nullptr;
+ }
+
+ int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
+
+ DotDimensionNumbers new_dim_numbers = dim_numbers;
+ new_dim_numbers.set_rhs_contracting_dimensions(
+ 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
+
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, *lhs->literal(), rhs->literal())));
+
+ // The new source dimension is wherever the non-batch non-contracting RHS
+ // dimension "went".
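+ // (The extra +1 skips the surviving LHS dimension, which precedes the
+ // surviving RHS dimension in the dot's output; see the "%indices 1->[1]"
+ // annotation in the DotOpBasic_2 test below.)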
+ int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
+ dim_numbers.rhs_batch_dimensions_size() + 1;
+
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, rhs->indices(), new_source_dim,
+ ArraySliceToVector(rhs->output_dims()), shape);
+}
+
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
+ Array* rhs) {
+ // Intuitively, if
+ //
+ // - The LHS of a dot product is a gathered sequence of rows from a constant
+ // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
+ //
+ // OR
+ //
+ // - The RHS of a dot product is a gathered sequence of columns from a
+ // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
+ // constant
+ //
+ // then the result of the dot product itself is a gather from a constant
+ // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
+ // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
+ // J].
+ //
+ // We do a general version of this rewrite here.
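+ // Concretely, in the DotOpBasic_0 test below: Const = s32[3,4],
+ // Indices = s32[5], LHS[I,J] = Const[Indices[I],J], and the s32[5,3] dot
+ // against a constant s32[4,3] RHS folds into a row gather from the
+ // precomputed s32[3,3] product Dot(Const, ConstRhs).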
+ VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs);
+ if (auto* lhs_indexed_array =
+ dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
+ if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
+ return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ lhs_indexed_array, rhs_constant);
+ }
+ }
+
+ if (auto* rhs_indexed_array =
+ dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
+ if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ rhs_indexed_array);
+ }
+ }
+
+ return nullptr;
+}
+
tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
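For orientation, a hedged sketch of driving the analysis over a computation; GetArrayFor and ToString are assumed from the existing IndexedArrayAnalysis interface rather than shown in this change:

    IndexedArrayAnalysis analysis;
    TF_ASSIGN_OR_RETURN(IndexedArrayAnalysis::Array* root_array,
                        analysis.GetArrayFor(computation->root_instruction()));
    // A ScalarIndexedConstantArray result here means the dot was folded into
    // a gather from a precomputed constant, as in the tests below.
    VLOG(1) << analysis.ToString(root_array, /*print_constants=*/true);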
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 8684430231..e923dc39f7 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -268,6 +268,18 @@ class IndexedArrayAnalysis {
tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
Array* indices);
+ StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
+
+ StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+
+ StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
+ const DotDimensionNumbers& dim_numbers,
+ Array* lhs, Array* rhs);
+
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
// ScalarIndexedArray as indices. If `source` happened to be a
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index fc2befe05b..5f4b42799b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -799,5 +799,170 @@ ENTRY main {
AssertArrayForRootExpressionIs(hlo_text, "%add");
}
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_lhs = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[3,3] s32[3,3] {
+ { 70, 80, 90 },
+ { 158, 184, 210 },
+ { 246, 288, 330 } })
+ %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
+ indices = s32[5] parameter(0)
+ dot_lhs = s32[3,5] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,3] s32[4,3] {
+ { 84, 99, 114 },
+ { 96, 114, 132 },
+ { 108, 129, 150 },
+ { 120, 144, 168 } })
+ %indices 0->[1]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_rhs = s32[3,5] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,4] s32[4,4] {
+ { 38, 44, 50, 56 },
+ { 83, 98, 113, 128 },
+ { 128, 152, 176, 200 },
+ { 173, 206, 239, 272 } })
+ %indices 1->[1])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_rhs = s32[5,3] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,3}
+ ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,4] s32[4,4] {
+ { 14, 32, 50, 68 },
+ { 32, 77, 122, 167 },
+ { 50, 122, 194, 266 },
+ { 68, 167, 266, 365 } })
+ %indices 1->[0])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
+ dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
+ indices = s32[4] parameter(0)
+ dot_rhs = s32[2,3,4] gather(gather_operand, indices),
+ output_window_dims={0,1},
+ elided_window_dims={2},
+ gather_dims_to_operand_dims={2},
+ index_vector_dim=1,
+ window_bounds={2,3,1}
+ ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
+ lhs_contracting_dims={2}, rhs_contracting_dims={1},
+ lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[2,2,2] s32[2,2,2] {
+ { { 22, 28 },
+ { 49, 64 } },
+ { { 220, 244 },
+ { 301, 334 } } })
+ %indices 3->[2])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpNegative) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
+ indices = s32[2] parameter(0)
+ dot_lhs = s32[3,2] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index d2af261008..32937b33b3 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) {
auto max_f32 = max_builder.Build();
auto builder = HloComputation::Builder("MapMaxFunction");
- auto lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
- auto rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1})));
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
@@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
+ auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
@@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) {
HloInstruction::CreateParameter(0, r0f32, "x"));
(void)param1;
const2_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto const2_f32 = const2_builder.Build();
auto builder = HloComputation::Builder("MapConstFunction");
auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
@@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
@@ -123,10 +123,10 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
auto max_f32 = max_builder.Build();
auto builder = HloComputation::Builder("MapSubFunction");
- auto lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
- auto rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1})));
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
@@ -142,7 +142,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
+ auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index bb7231c8c8..9e7a15f033 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -167,7 +167,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
@@ -356,7 +356,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
HloInstruction* unary1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0));
HloInstruction* unary2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
@@ -380,7 +380,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 524d3234eb..8652599dc6 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -74,7 +74,7 @@ cc_library(
hdrs = ["executable.h"],
deps = [
":executor",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 9816acf650..8d40c08d55 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index fedc83c8f8..46a6d57353 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -59,7 +59,6 @@ namespace xla {
// anonymous namespace, instead of three or four spread all over this file.
namespace {
-
} // namespace
std::ostream& operator<<(std::ostream& out,
@@ -113,14 +112,18 @@ LayoutConstraints::LayoutConstraints(
HloComputation* computation)
: points_to_analysis_(points_to_analysis), computation_(computation) {
// Gather all array-shaped logical buffers into unconstrained_buffer_ids.
- for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers();
- id++) {
- auto& buffer = points_to_analysis_.logical_buffer(id);
- // The points to analysis is computed per module, restrict constraints to
- // array buffers in this computation.
- if (buffer.IsArray() && buffer.instruction()->parent() == computation) {
- unconstrained_buffer_ids_.insert(buffer.id());
- }
+ for (HloInstruction* inst : computation_->instructions()) {
+ points_to_analysis_.GetPointsToSet(inst).ForEachElement(
+ [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
+ for (const LogicalBuffer* buffer : buffers) {
+ // The points-to analysis is computed per module; restrict
+ // constraints to array buffers in this computation.
+ if (buffer->IsArray() &&
+ buffer->instruction()->parent() == computation) {
+ unconstrained_buffer_ids_.insert(buffer->id());
+ }
+ }
+ });
}
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index a673901c75..a16fa75e30 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -141,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
- auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
- auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
@@ -192,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// match their source).
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -229,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -240,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateTuple({constant0, constant1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
@@ -274,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
// tuple and assigning the layouts of the copied arrays as needed.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto inner_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
auto nested_tuple = builder.AddInstruction(
@@ -584,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
auto builder = HloComputation::Builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(input_shape, constant, {}));
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -770,8 +770,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
false_builder.AddInstruction(
HloInstruction::CreateParameter(0, tshape, "param"));
// Using infeed as layout assignment does not mess up with it.
- auto token =
- false_builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
auto infeed = false_builder.AddInstruction(
HloInstruction::CreateInfeed(xshape, token, ""));
auto infeed_data = false_builder.AddInstruction(
@@ -803,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
builder.AddInstruction(HloInstruction::CreateUnary(
constant0->shape(), HloOpcode::kBitcast, constant0));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index f1e7fc2953..6f1e04a1c6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -21,6 +21,11 @@ filegroup(
]),
)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
cc_library(
name = "alias_analysis",
srcs = ["alias_analysis.cc"],
@@ -37,12 +42,25 @@ cc_library(
],
)
+tf_cc_test(
+ name = "alias_analysis_test",
+ srcs = ["alias_analysis_test.cc"],
+ deps = [
+ ":alias_analysis",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "llvm_util",
srcs = ["llvm_util.cc"],
hdrs = ["llvm_util.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -107,11 +125,30 @@ cc_library(
)
cc_library(
+ name = "kernel_tiling",
+ srcs = ["kernel_tiling.cc"],
+ hdrs = ["kernel_tiling.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
name = "fused_ir_emitter",
srcs = ["fused_ir_emitter.cc"],
hdrs = ["fused_ir_emitter.h"],
deps = [
":ir_array",
+ ":kernel_tiling",
":llvm_util",
":loop_emitter",
":tuple_ops",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index f200a08a3c..93a8c130e1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -35,9 +35,10 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
llvm_ir::IrArray* array,
const ShapeIndex& index) {
BufferAllocation::Slice buffer_slice;
- if (hlo.opcode() == HloOpcode::kParameter) {
- // Parameters may alias with each other but may not alias with our temporary
- // buffers.
+ if (hlo.opcode() == HloOpcode::kParameter &&
+ hlo.parent() == hlo.parent()->parent()->entry_computation()) {
+ // Entry computation parameters may alias with each other but may not alias
+ // with our temporary buffers.
buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0);
} else {
const std::set<BufferAllocation::Slice> slices =
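This hunk narrows the special case: only parameters of the entry computation receive the shared kParameterAllocation slice, because parameters of embedded computations (such as a while body or condition) can be fed from temporary buffers and so may alias them. A minimal sketch of the predicate the new condition expresses, relying on the fact that an instruction's parent() is its computation and the computation's parent() is its module:

    // Sketch only, not part of the patch: true iff `hlo` is a parameter of the
    // module's entry computation.
    bool IsEntryComputationParameter(const HloInstruction& hlo) {
      return hlo.opcode() == HloOpcode::kParameter &&
             hlo.parent() == hlo.parent()->parent()->entry_computation();
    }
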
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
new file mode 100644
index 0000000000..2552ff4a6a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class AliasAnalysisTest : public CpuCodegenTest {};
+
+void FakeCustomCallTarget(float* out, float** in) {}
+
+REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
+
+TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) {
+ const char* hlo_string = R"(
+HloModule while
+
+body {
+ const.0.125 = f32[] constant(0.125)
+ body.state = f32[] parameter(0)
+ ROOT add.2.2 = f32[] add(const.0.125, body.state)
+}
+
+condition {
+ const.100 = f32[] constant(100)
+ condition.state = f32[] parameter(0)
+ addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
+ add = f32[] add(addend, condition.state)
+ ROOT greater-than = pred[] greater-than(const.100, add)
+}
+
+ENTRY while3 {
+ const.0 = f32[] constant(0)
+ ROOT while = f32[] while(const.0), condition=condition, body=body
+}
+)";
+
+ CompileAndVerifyIr(hlo_string, R"(
+; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
+; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
+; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]]
+;
+; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
+; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
+; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
+; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
+;
+; CHECK-LABEL: @while3(
+
+![[alias_scope_md_for_store]] = !{![[buffer_idx_0:.*]]}
+![[buffer_idx_0]] = !{!"buffer: {index:0, offset:0, size:4}", ![[aa_md_root:.*]]}
+![[aa_md_root]] = !{!"XLA global AA domain"}
+![[buffer_idx_1:.*]] = !{!"buffer: {index:1, offset:0, size:4}", !3}
+![[buffer_idx_1_offset_16:.*]] = !{!"buffer: {index:1, offset:16, size:1}", !3}
+![[noalias_md_for_load]] = !{![[buffer_idx_1_offset_16]], ![[buffer_idx_1]]}
+}
+)");
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
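The test above leans on FileCheck variable capture: `%[[name:.*]]` (or `![[name:.*]]` for metadata) binds a variable at its first match, and a later bare `[[name]]` must match the same text. That is how the store's !alias.scope metadata is forced to reappear on the load in @condition. A minimal illustration of the mechanism (hypothetical IR, not from this test):

    ; CHECK: store float %{{.*}}, float* %{{.*}}, !alias.scope ![[md:.*]]
    ; CHECK: load float, float* %{{.*}}, !alias.scope ![[md]]
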
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index d909845a3a..b12ce97e28 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -119,7 +119,24 @@ Status FusedIrEmitter::HandleGetTupleElement(
}
Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
- generators_[parameter] = [=](const IrArray::Index& index) {
+ generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* {
+ if (tiled_parameter_info_) {
+ if (llvm::Value* param_tile_buffer =
+ tiled_parameter_info_->GetBufferForParameter(
+ parameter->parameter_number())) {
+ // TODO(jlebar): Add AA metadata to this load. Tile buffers are global
+ // variables, so LLVM's points-to analysis doesn't help us much. And we
+ // want the AA info to be present before address spaces are inferred
+ // (which is pretty late in the pipeline), so even if we had
+ // address-space-based AA in LLVM, it wouldn't help us much here.
+ return ir_builder_->CreateLoad(
+ ir_builder_->CreateGEP(
+ param_tile_buffer,
+ {index.GetConstantWithIndexType(0), tiled_parameter_info_->x(),
+ tiled_parameter_info_->y()}),
+ "tiled_buffer");
+ }
+ }
return parameter_arrays_[parameter->parameter_number()]
.EmitReadArrayElement(index, ir_builder_);
};
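The new lambda gives parameters a two-level read: consult the tiled-parameter info first and read the staged tile buffer if one exists, otherwise fall back to the ordinary array read. A self-contained analogue of that fallback pattern in plain C++ (illustrative names and types, no LLVM involved):

    #include <functional>
    #include <vector>

    using Generator = std::function<float(int)>;

    struct TiledInfo {
      // tiles[i] holds the staged values for parameter i; empty means untiled.
      std::vector<std::vector<float>> tiles;
      const std::vector<float>* GetBufferForParameter(int i) const {
        return tiles[i].empty() ? nullptr : &tiles[i];
      }
    };

    Generator MakeParameterGenerator(int param_no, const TiledInfo* info,
                                     const std::vector<float>* backing) {
      return [=](int index) -> float {
        if (info != nullptr) {
          if (const auto* tile = info->GetBufferForParameter(param_no)) {
            return (*tile)[index];  // fast path: read from the staged tile
          }
        }
        return (*backing)[index];  // fall back to the untiled array read
      };
    }
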
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index b3b6026ef1..a6ceec7b23 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -56,6 +57,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
+ tiled_parameter_info_(nullptr),
elemental_emitter_(elemental_emitter),
ir_builder_(elemental_emitter->ir_builder()),
module_(elemental_emitter->module()) {}
@@ -86,9 +88,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
return it->second;
}
+ void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) {
+ tiled_parameter_info_ = info;
+ }
+
private:
  // Arrays of parameters of the fusion instruction.
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+ const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
ElementalIrEmitter* elemental_emitter_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index ea10cef49a..dcf9838d80 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -422,9 +422,11 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
llvm::IRBuilder<>* ir_builder) const {
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
- return IrArray(
+ IrArray new_irarray(
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
new_shape);
+ new_irarray.metadata_ = metadata_;
+ return new_irarray;
}
/* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index 4648c6d7ac..0777c49923 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -114,19 +114,19 @@ class IrArray {
size_t size() const { return multidim().size(); }
llvm::Value* operator[](size_t i) const { return multidim()[i]; }
- llvm::Value*& operator[](size_t i) { return multidim()[i]; }
+ llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; }
- void push_back(llvm::Value* value) { multidim().push_back(value); }
+ void push_back(llvm::Value* value) { mutable_multidim().push_back(value); }
void InsertAt(int64 index, llvm::Value* value) {
CHECK_LE(index, size());
- multidim().insert(multidim().begin() + index, value);
+ mutable_multidim().insert(mutable_multidim().begin() + index, value);
}
using iterator = std::vector<llvm::Value*>::iterator;
using const_iterator = std::vector<llvm::Value*>::const_iterator;
- iterator begin() { return multidim().begin(); }
- iterator end() { return multidim().end(); }
+ iterator begin() { return mutable_multidim().begin(); }
+ iterator end() { return mutable_multidim().end(); }
const_iterator begin() const { return multidim().begin(); }
const_iterator end() const { return multidim().end(); }
@@ -185,7 +185,7 @@ class IrArray {
private:
// Changing the multi-dimensional index invalidates the linear index.
- std::vector<llvm::Value*>& multidim() {
+ std::vector<llvm::Value*>& mutable_multidim() {
linear_ = nullptr;
return multidim_;
}
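The rename makes the invariant noted in the comment explicit at every call site: any path that may mutate the multi-dimensional index must also drop the cached linear index, and only mutable_multidim() can do both. A minimal sketch of the idiom with illustrative types:

    #include <vector>

    class CachedIndex {
     public:
      const std::vector<long>& multidim() const { return multidim_; }
      void push_back(long v) { mutable_multidim().push_back(v); }

     private:
      // All mutation funnels through this accessor, which invalidates the
      // cached linear form before handing out a mutable reference.
      std::vector<long>& mutable_multidim() {
        linear_valid_ = false;
        return multidim_;
      }
      std::vector<long> multidim_;
      bool linear_valid_ = false;
    };
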
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index 1f6e3c829f..98d0ceb3e2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -56,10 +56,11 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::If(
- llvm::Value* condition, const std::function<Status()>& true_block_generator,
+ tensorflow::StringPiece name, llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
+ llvm_ir::EmitIfThenElse(condition, name, ir_builder_);
ir_builder_->SetInsertPoint(&if_data.true_block->back());
TF_RETURN_IF_ERROR(true_block_generator());
ir_builder_->SetInsertPoint(&if_data.false_block->back());
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 6f7a9d94e3..9d770cc4c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -203,16 +203,30 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- Status If(llvm::Value* condition,
+ Status If(tensorflow::StringPiece name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator =
[]() -> Status { return Status::OK(); });
+ Status If(llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator =
+ []() -> Status { return Status::OK(); }) {
+ return If("", condition, true_block_generator, false_block_generator);
+ }
+
void IfReturnVoid(llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {
}) {
- TF_CHECK_OK(If(condition,
+ IfReturnVoid("", condition, true_block_generator, false_block_generator);
+ }
+
+ void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition,
+ const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator = []() {
+ }) {
+ TF_CHECK_OK(If(name, condition,
[&]() {
true_block_generator();
return Status::OK();
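The header follows a common pattern for threading an optional name through an API without touching existing callers: the named overload does the work, and the legacy signature forwards with an empty name. A minimal sketch of the shape of that pattern (simplified types, not the real library):

    #include <functional>
    #include <string>

    struct Status {
      static Status OK() { return Status(); }
    };

    // Named variant; the label would name the emitted basic blocks.
    Status If(const std::string& name, bool condition,
              const std::function<Status()>& then_fn) {
      (void)name;
      return condition ? then_fn() : Status::OK();
    }

    // Legacy signature preserved by forwarding with an empty name.
    Status If(bool condition, const std::function<Status()>& then_fn) {
      return If("", condition, then_fn);
    }
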
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
new file mode 100644
index 0000000000..533b75cdae
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace llvm_ir {
+
+namespace {
+// Returns the index of the first element of each maximal run of consecutive
+// values in the given array. For example:
+// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
+std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+ std::vector<size_t> is = {0};
+ for (size_t i = 1; i < xs.size(); ++i) {
+ if (1 != xs[i] - xs[i - 1]) {
+ is.push_back(i);
+ }
+ }
+ return is;
+}
+
+// Merges each run of dimensions of the given shape that starts at one of the
+// indices in `segs` into a single dimension.
+Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
+ const Shape& shape) {
+ std::vector<int64> dimensions;
+ for (size_t i = 1; i <= segs.size(); ++i) {
+ dimensions.push_back(std::accumulate(
+ shape.dimensions().begin() + segs[i - 1],
+ shape.dimensions().begin() +
+ (segs.size() == i ? shape.dimensions().size() : segs[i]),
+ 1, std::multiplies<int64>()));
+ }
+ return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
+ dimensions);
+}
+} // namespace
+
+tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
+ const Shape& a, const Shape& b) {
+ if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
+ return tensorflow::gtl::nullopt;
+ }
+
+ std::vector<int64> perm(a.dimensions().size());
+ {
+ auto layout_a_orig = LayoutUtil::MinorToMajor(a);
+ std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
+ auto layout_b_orig = LayoutUtil::MinorToMajor(b);
+ std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
+ for (size_t i = 0; i < perm.size(); ++i) {
+ perm[i] = PositionInContainer(layout_b, layout_a[i]);
+ }
+ }
+ auto segs = ConsecutiveSegments(perm);
+ if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) {
+ Shape norm_a =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
+ Shape reduced_a = MergeDimensions(segs, norm_a);
+ auto reduced_a_dims = reduced_a.dimensions();
+ std::vector<int64> dims_021;
+ if (2 == segs.size()) {
+ // The logical component-0 is of size one.
+ dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]};
+ } else {
+ dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]};
+ }
+
+ return dims_021;
+ }
+
+ return tensorflow::gtl::nullopt;
+}
+
+IrArray::Index GetUnreducedOutputIndex(
+ const IrArray::Index& reduced_output_index,
+ const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
+ llvm::IRBuilder<>* ir_builder) {
+ auto bounds = reduced_output_shape.dimensions();
+ auto minor_to_major = reduced_output_shape.layout().minor_to_major();
+ llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0);
+ int64 multiplier = 1;
+ for (int i = 0; i < reduced_output_index.size(); ++i) {
+ int64 dim = minor_to_major[i];
+ llvm::Value* addend = ir_builder->CreateMul(
+ reduced_output_index[dim],
+ reduced_output_index.GetConstantWithIndexType(multiplier),
+ "linearizing",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ linear_index = ir_builder->CreateAdd(linear_index, addend, "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ multiplier *= bounds[dim];
+ }
+
+ return IrArray::Index(linear_index, unreduced_output_shape, ir_builder);
+}
+
+} // namespace llvm_ir
+} // namespace xla
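To make the two file-local helpers concrete: for the permutation {1, 2, 0}, the runs of consecutive values start at indices 0 and 2, and merging the dimensions {8, 16, 32} at those boundaries collapses the leading run into one dimension. GetUnreducedOutputIndex then linearizes the reduced index with the usual minor-to-major stride accumulation before re-deriving an index in the unreduced shape. A standalone re-implementation of the helpers on plain vectors (for illustration only; the real code operates on xla::Shape):

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <numeric>
    #include <vector>

    std::vector<size_t> ConsecutiveSegments(const std::vector<int64_t>& xs) {
      std::vector<size_t> is = {0};
      for (size_t i = 1; i < xs.size(); ++i) {
        if (xs[i] - xs[i - 1] != 1) is.push_back(i);
      }
      return is;
    }

    std::vector<int64_t> MergeDimensions(const std::vector<size_t>& segs,
                                         const std::vector<int64_t>& dims) {
      std::vector<int64_t> merged;
      for (size_t i = 1; i <= segs.size(); ++i) {
        const size_t end = (i == segs.size()) ? dims.size() : segs[i];
        merged.push_back(std::accumulate(dims.begin() + segs[i - 1],
                                         dims.begin() + end, int64_t{1},
                                         std::multiplies<int64_t>()));
      }
      return merged;
    }

    int main() {
      const auto segs = ConsecutiveSegments({1, 2, 0});        // {0, 2}
      const auto merged = MergeDimensions(segs, {8, 16, 32});  // {128, 32}
      std::printf("segments: %zu %zu\n", segs[0], segs[1]);
      std::printf("merged: %lld %lld\n", static_cast<long long>(merged[0]),
                  static_cast<long long>(merged[1]));
      return 0;
    }
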
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
new file mode 100644
index 0000000000..6f1268fffb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
+
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// About 0-2-1 transpose:
+//
+// If a shape can be viewed as three logical components 0-1-2 in the order of
+// major to minor, a 0-2-1 transpose reorders those logical components to
+// 0-2-1. We call the shape being transposed the input shape and the
+// transposed shape the output shape. The logical views of the input and
+// output shapes are called the 0-1-2 (reduced input) shape and the 0-2-1
+// (reduced output) shape, respectively. The original input and output shapes
+// are called the unreduced input and output shapes.
+
+// If `b` is a 0-2-1 transpose of `a` (viewed as 0-1-2), returns the
+// dimensions of the reduced (0-2-1) shape of `b`.
+tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b);
+
+// Returns the unreduced output index corresponding to the given reduced
+// output index.
+IrArray::Index GetUnreducedOutputIndex(
+ const IrArray::Index& reduced_output_index,
+ const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
+ llvm::IRBuilder<>* ir_builder);
+
+// Information about tiled parameters, used to support IR emission for the
+// 0-2-1 transpose.
+class TiledParameterInfo {
+ public:
+ TiledParameterInfo(tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers,
+ llvm::Value* y, llvm::Value* x)
+ : param_buffers_(param_buffers), y_(y), x_(x) {}
+
+ llvm::Value* x() const { return x_; }
+ llvm::Value* y() const { return y_; }
+
+ void set_x(llvm::Value* x) { x_ = x; }
+ void set_y(llvm::Value* y) { y_ = y; }
+
+ llvm::Value* GetBufferForParameter(int64 index) const {
+ return param_buffers_[index];
+ }
+
+ private:
+  // param_buffers_[i] stores the tile buffer for the ith parameter, or
+  // nullptr if the parameter is not tiled.
+ tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers_;
+ // The y coordinate within a tile.
+ llvm::Value* y_;
+ // The x coordinate within a tile.
+ llvm::Value* x_;
+};
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
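As a worked example of the conventions above (shapes hypothetical): for a plain 2-D transpose the logical component 0 has size one, so a row-major R x C matrix yields the reduced output shape {1, C, R}; for a rank-3 shape whose layout exchanges the two minor components of an otherwise identical shape, dimensions [2, 3, 4] yield {2, 4, 3}. A caller would consume the result roughly like this (sketch; a_shape and b_shape are placeholders):

    if (auto dims_021 = llvm_ir::FindTranspose021(a_shape, b_shape)) {
      // (*dims_021)[0], [1], [2] bound the 0-2-1 tile loops for this
      // transpose; tensorflow::gtl::nullopt means the pair is not a
      // 0-2-1 transpose and the generic path must be used.
    }
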
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 97bacc34b5..6c55361b44 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 4a10ec466d..9c51861eac 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 7c63c0acc7..39fe3c7835 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -75,19 +75,6 @@ PlatformUtil::GetSupportedPlatforms() {
auto* platform = platform_pair.second;
auto compiler_status = Compiler::GetForPlatform(platform);
if (compiler_status.ok()) {
- if (platform->VisibleDeviceCount() > 0) {
- LOG(INFO) << "platform " << platform->Name() << " present with "
- << platform->VisibleDeviceCount() << " visible devices";
- } else {
- LOG(WARNING) << "platform " << platform->Name() << " present but no "
- << "visible devices found";
- }
- // Note: currently we call zero device platforms "supported" on the basis
- // that, if the platform support was linked in, it was probably intended
- // to be used for execution, and this way we can flag an error.
- //
- // TODO(b/33730287) If we want an alternative version of this behavior we
- // could add an --xla_fallback_to_host flag.
platforms.push_back(platform);
} else {
LOG(INFO) << "platform " << platform->Name() << " present but no "
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index 49ec38eb62..ca86c5d13e 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 13e2d3258e..ad3b662c20 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -175,8 +175,9 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
- auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<bool>({{true, true, false}, {false, false, true}})));
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, true, false}, {false, false, true}})));
auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
@@ -255,12 +256,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {3, 2});
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
@@ -309,7 +310,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0"));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -348,7 +349,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({9, 8, 7})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({9, 8, 7})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
auto reshape1 =
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index bafe14d6f4..9b1ce143c6 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -1543,45 +1544,45 @@ class GatherShapeInferenceTest : public ShapeInferenceTest {
};
TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_vector_32_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
+ /*window_bounds=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
}
TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
- /*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_vector_32_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{1},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1),
+ /*window_bounds=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
}
TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
- HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4),
+ /*window_bounds=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1592,7 +1593,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1609,7 +1610,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1627,7 +1628,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1646,7 +1647,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{0, 1, 2, 3, 4},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1664,7 +1665,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{0, 1, 2, 3},
/*elided_window_dims=*/{0},
/*gather_dims_to_operand_dims=*/{0},
@@ -1679,10 +1680,11 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/1),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1693,10 +1695,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/0),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1707,10 +1710,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
- HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
- /*index_vector_dim=*/0),
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0),
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -1722,7 +1726,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 8, 7},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1739,7 +1743,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 7},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1756,7 +1760,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 99, 100, 101},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1772,7 +1776,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 9},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1788,7 +1792,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{4},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1806,7 +1810,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 19},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1823,7 +1827,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{0, 1, 2, 3, 3},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1841,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
@@ -1860,7 +1864,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
@@ -1878,7 +1882,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
@@ -1896,7 +1900,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{2, 1},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1911,7 +1915,7 @@ TEST_F(GatherShapeInferenceTest,
TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{2},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1928,7 +1932,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1946,7 +1950,7 @@ TEST_F(GatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7},
/*elided_window_dims=*/{1},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
@@ -1962,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest,
TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
- HloInstruction::MakeGatherDimNumbers(
+ HloGatherInstruction::MakeGatherDimNumbers(
/*output_window_dims=*/{4, 5, 6, 7, 8},
/*elided_window_dims=*/{},
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 4c5038a009..7232c658b3 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -44,6 +44,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
StatusOr<std::unique_ptr<Literal>> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
+ substream->ThenWaitFor(stream);
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
@@ -64,6 +65,7 @@ Status TransferManager::TransferLiteralToDevice(
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
se::Stream* substream = stream->GetOrCreateSubStream();
+ substream->ThenWaitFor(stream);
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
TF_RETURN_IF_ERROR(
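Both hunks apply the same fix: the transfer substream must first wait on its parent stream so the transfer is ordered after work already enqueued there, while still using a substream to avoid deadlock when called from a host callback. The resulting idiom, as it appears in the patch:

    se::Stream* substream = stream->GetOrCreateSubStream();
    substream->ThenWaitFor(stream);  // order after prior work on `stream`
    auto cleanup = tensorflow::gtl::MakeCleanup(
        [&]() { stream->ReturnSubStream(substream); });
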
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index e384359642..82c599e482 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <set>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -167,16 +167,6 @@ class TransferManager {
const se::Platform* platform);
protected:
- // Transfer a memory block of the given size from 'source' buffer to the
- // Infeed interface of the device using the given executor.
- //
- // size is the size to transfer from source in bytes.
- //
- // source is the source data that must be in the target-dependent layout that
- // the Infeed HLO used in the computation expects.
- virtual Status TransferBufferToInfeed(se::StreamExecutor* executor,
- int64 size, const void* source) = 0;
-
// Transfer a memory block of the given size from the device source into the
// 'destination' buffer.
//
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index cccb8f2fbb..7051a4cf51 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -160,11 +160,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
auto builder = HloComputation::Builder("entry");
// (1.0 + 2.0) * (2.0 - 3.0)
HloInstruction* const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* const3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
const1->shape(), HloOpcode::kAdd, const1, const2));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 226d0af5d2..0ac8df4271 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase {
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
// Create a tuple which contains duplicate elements.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant, constant, constant}));
@@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
// the same.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto copy = builder.AddInstruction(
@@ -317,8 +317,8 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
// Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto send = builder.AddInstruction(
HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
@@ -343,7 +343,7 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
- auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
@@ -365,16 +365,16 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
// set containing the union of both sides.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant2, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
@@ -403,7 +403,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, tuple_shape, "param1"));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
auto copy = builder.AddInstruction(
@@ -443,16 +443,16 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
// Select from two identical tuples. The result should not be ambiguous.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
@@ -474,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
// the right values.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto inner_tuple2 = builder.AddInstruction(
@@ -488,7 +488,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
@@ -521,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
// have the operand of the bitcast in its points-to set.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
constant2->shape(), HloOpcode::kBitcast, constant2));
auto tuple =
@@ -557,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
  // copy correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
- Literal::CreateR1<float>({2.0, 42}).get()})));
+ auto tuple_constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
@@ -579,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
// times. Verify buffer alias sets.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple = builder.AddInstruction(
@@ -620,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
auto tuple_element1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
// Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
auto update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, tuple_element1, ones));
@@ -868,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -962,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -1016,9 +1017,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
@@ -1027,7 +1028,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1049,7 +1050,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1057,7 +1058,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
@@ -1122,7 +1123,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index d3635eae81..39b693872d 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 23519e445e..32e69c335b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -53,7 +53,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "param"));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
return module->AddEmbeddedComputation(builder.Build());
}
@@ -125,7 +125,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
builder.AddInstruction(HloInstruction::CreateUnary(
scalar_s32, HloOpcode::kNegate, mul_result));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(4)));
HloInstruction* sub_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kSubtract, negate_result, constant));
@@ -273,7 +273,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
HloComputation::Builder builder(TestName());
auto* scalar_param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_s32, "param"));
- auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = builder.AddInstruction(HloInstruction::CreateToken());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
@@ -323,7 +323,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
HloComputation::Builder builder(TestName());
auto* scalar_param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_s32, "param"));
- auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = builder.AddInstruction(HloInstruction::CreateToken());
auto* init_value = builder.AddInstruction(
HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 3c83049216..2e1571943e 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest,
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* true_op = while_op->while_body()->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
TF_ASSERT_OK(true_op->AddControlDependencyTo(
while_op->while_body()->root_instruction()));
ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -175,10 +175,10 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateSendDone(send));
@@ -192,7 +192,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- auto* token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* recv = while_body->AddInstruction(
HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token,
/*channel_id=*/0));
@@ -211,7 +211,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = while_body->AddInstruction(HloInstruction::CreateToken());
while_body->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 473eab2ea8..1ef17b9d7d 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
@@ -38,7 +39,7 @@ static StatusOr<HloComputation*> WidenWhileCondition(
// the root instruction later. We later change the root instruction to
// something more appropriate.
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
}();
@@ -154,7 +155,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
{&loop_state_shape}, scalar_pred, "while_cond"));
HloInstruction* trip_count_constant = cond_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(trip_count)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
HloInstruction* param = cond_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
@@ -175,7 +176,7 @@ static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
CreateComputationWithSignature(
{&loop_state_shape}, loop_state_shape, "while_body"));
HloInstruction* one = body_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
HloInstruction* param = body_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
@@ -203,7 +204,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
std::vector<HloInstruction*> init_values_with_indvar;
init_values_with_indvar.reserve(init_values.size() + 1);
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
index 44b0ec5cd4..83d696fe09 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -32,7 +32,8 @@ StatusOr<bool> ZeroSizedHloElimination::Run(HloModule* module) {
for (HloComputation* comp : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (instruction->HasSideEffect() ||
- !ShapeUtil::IsArray(instruction->shape())) {
+ !ShapeUtil::IsArray(instruction->shape()) ||
+ instruction->opcode() == HloOpcode::kConstant) {
continue;
}
if (comp->IsRemovable(instruction) &&
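Context for the new kConstant guard: the pass canonicalizes removable zero-element arrays by replacing them with a freshly created zero-sized constant, so a constant is already in normal form and rewriting it would report changed=true on every run. A minimal sketch of the rewrite step under that assumption (helper names such as ShapeUtil::HasZeroElements and Literal::CreateFromShape follow the APIs in this tree, but the snippet is illustrative, not the patch itself):

    if (comp->IsRemovable(instruction) &&
        ShapeUtil::HasZeroElements(instruction->shape())) {
      // Without the opcode check above, this would also fire on constants,
      // re-creating an equivalent zero-sized constant forever.
      TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(
                           Literal::CreateFromShape(instruction->shape()))));
      changed = true;
    }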
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
index c6bd013a1a..b9ef18892d 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -67,12 +67,19 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) {
}
TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) {
- auto token = builder_.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = builder_.AddInstruction(HloInstruction::CreateToken());
builder_.AddInstruction(
HloInstruction::CreateSend(zero_sized_param_, token, 0));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination());
EXPECT_FALSE(changed);
}
+TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) {
+ builder_.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination());
+ EXPECT_FALSE(changed);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 7ee366b27a..caad31d6ce 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) {
TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
}
+void ShapeLayout::ResetLayout(const Layout& layout,
+ ShapeIndexView shape_index) {
+ CHECK(ShapeUtil::IsTuple(shape_));
+ *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() =
+ layout;
+ TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
+}
+
bool ShapeLayout::operator==(const ShapeLayout& other) const {
return ShapeUtil::Equal(shape_, other.shape_);
}
diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h
index 36806da599..214cf98854 100644
--- a/tensorflow/compiler/xla/shape_layout.h
+++ b/tensorflow/compiler/xla/shape_layout.h
@@ -72,6 +72,10 @@ class ShapeLayout {
// tuple.
void ResetLayout(const Layout& layout);
+ // Resets the layout on the shape at the provided ShapeIndex to the provided
+ // layout. Shape must be a tuple.
+ void ResetLayout(const Layout& layout, ShapeIndexView shape_index);
+
// Returns a string representation of this object.
string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }
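A hypothetical usage of the new overload (the tuple shape and layout below are illustrative, not taken from the patch):

    // Give element 1 of a (f32[2,3], f32[4,5]) tuple a row-major layout
    // while leaving element 0 untouched.
    Shape tuple_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {2, 3}), ShapeUtil::MakeShape(F32, {4, 5})});
    ShapeLayout shape_layout(tuple_shape);
    shape_layout.ResetLayout(LayoutUtil::MakeLayout({1, 0}),
                             /*shape_index=*/{1});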
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 56d24423c4..f4668c0f55 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -46,28 +46,14 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-string ShapeIndex::ToString() const {
- return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
-}
+string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
string ShapeIndexView::ToString() const {
- return StrCat("{",
- tensorflow::str_util::Join(
- tensorflow::gtl::make_range(begin_, end_), ","),
- "}");
+ return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
- if (size() != other.size()) {
- return false;
- }
- for (auto it = begin(), other_it = other.begin(); it != end();
- ++it, ++other_it) {
- if (*it != *other_it) {
- return false;
- }
- }
- return true;
+ return indices_ == other.indices_;
}
bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
@@ -1126,12 +1112,41 @@ Status ForEachMutableSubshapeHelper(
for (auto dim : Permute(permutation, shape.dimensions())) {
new_shape.add_dimensions(dim);
}
+
+ // If `shape` has a layout, by contract we choose a new layout such that the
+ // transpose defined by this permutation is a bitcast.
+ //
+ // Some formalism helps to understand the correct way to do this. We're going
+ // to do algebra in the group of permutations of the dimensions of `shape`.
+ //
+ // Since the order of `shape`'s dimensions is not permuted relative to itself,
+ // `shape`'s list of dimensions is isomorphic to the identity I.
+ //
+ // Let `shape`'s layout be L. A layout is a permutation which maps a
+ // minor-to-major physical layout to the order of a shape's logical dims.
+ // Therefore the inverse of a layout maps from logical to physical dims, and so
+ // the physical layout of I is simply L'.I = L', where L' is the inverse of L.
+ //
+ // Let the argument `permutation` be P. This is a permutation over `shape`'s
+ // dimensions, so our return value will be a shape with dims P.I = P. Our
+ // goal is to construct a layout permutation L* that we can apply to P such
+ // that the physical dimension ordering of the returned shape is the same
+ // as that of the original shape, namely L'.
+ //
+ // Our returned shape has dims P and layout L*, so its in-memory layout is
+ // L*'.P. Setting this equal to L' and solving for L*, we get:
+ //
+ // L*'.P = L' =>
+ // L*' = L'P' =>
+ // L* = P.L
+ //
if (shape.has_layout()) {
CHECK(LayoutUtil::IsDenseArray(shape));
Layout* new_layout = new_shape.mutable_layout();
new_layout->set_format(DENSE);
new_layout->clear_minor_to_major();
- for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
+ for (auto index : ComposePermutations(
+ permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
new_layout->add_minor_to_major(index);
}
if (shape.layout().padded_dimensions_size() > 0) {
@@ -1141,6 +1156,13 @@ Status ForEachMutableSubshapeHelper(
new_layout->add_padded_dimensions(dim);
}
}
+ // The permutation accepted by TransposeIsBitcast is the inverse of the
+ // permutation here.
+ CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
+ << "shape=" << HumanStringWithLayout(shape)
+ << ", new_shape=" << HumanStringWithLayout(new_shape)
+ << ", permutation={" << tensorflow::str_util::Join(permutation, ",")
+ << "}";
}
return new_shape;
}
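A concrete instance of the algebra above (values are illustrative): take dims {10, 100, 1000} with layout L = {0, 1, 2} and permutation P = {1, 2, 0}.

    // New dims satisfy new[P[i]] = old[i]        => {1000, 10, 100}
    // New layout L* = ComposePermutations(P, L),
    //   i.e. L*[i] = P[L[i]]                     => {1, 2, 0}
    // Physical (minor-to-major) order of the new shape:
    //   new dim 1 (10), new dim 2 (100), new dim 0 (1000)
    // which matches the original physical order {10, 100, 1000}, so the
    // transpose is a bitcast -- exactly what the added
    // CHECK(TransposeIsBitcast(...)) above verifies.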
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 5ae04451d3..17c1d7b10a 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -110,31 +110,33 @@ class ShapeIndex {
class ShapeIndexView {
public:
ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
- : ShapeIndexView(shape_index.data() + offset,
- shape_index.data() + shape_index.size()) {
+ : indices_(shape_index.data() + offset, shape_index.size() - offset) {
CHECK_LE(offset, shape_index.size());
}
- ShapeIndexView(std::initializer_list<int64> indices)
- : ShapeIndexView(indices.begin(), indices.end()) {}
+ ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
ShapeIndexView(const ShapeIndexView& other) = default;
using iterator = const int64*;
- iterator begin() const { return begin_; }
- iterator end() const { return end_; }
- int64 size() const { return std::distance(begin_, end_); }
- bool empty() const { return begin_ == end_; }
+ iterator begin() const { return indices_.begin(); }
+ iterator end() const { return indices_.end(); }
+ int64 size() const { return indices_.size(); }
+ bool empty() const { return indices_.empty(); }
int64 front() const {
CHECK(!empty());
- return *begin_;
+ return indices_.front();
}
ShapeIndexView ConsumeFront() const {
- CHECK(!empty());
- auto new_begin = begin_;
- ++new_begin;
- return ShapeIndexView(new_begin, end_);
+ ShapeIndexView result = *this;
+ result.indices_.pop_front();
+ return result;
+ }
+ ShapeIndexView ConsumeBack() const {
+ ShapeIndexView result = *this;
+ result.indices_.pop_back();
+ return result;
}
- ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); }
+ ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
bool operator==(const ShapeIndexView& other) const;
bool operator!=(const ShapeIndexView& other) const;
@@ -142,10 +144,7 @@ class ShapeIndexView {
string ToString() const;
private:
- ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}
-
- iterator begin_;
- iterator end_;
+ tensorflow::gtl::ArraySlice<int64> indices_;
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
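Backing ShapeIndexView with a tensorflow::gtl::ArraySlice<int64> replaces the hand-rolled begin_/end_ pair: equality falls out of the slice's operator==, and ConsumeFront/ConsumeBack become O(1) pointer adjustments via pop_front/pop_back. A hedged sketch of the intended consumption pattern (GetSubshapeView is a hypothetical helper, mirroring what the new ShapeIndexViewTest exercises below):

    // Walk into a tuple shape one index at a time.
    const Shape& GetSubshapeView(const Shape& shape, ShapeIndexView index) {
      if (index.empty()) {
        return shape;
      }
      return GetSubshapeView(shape.tuple_shapes(index.front()),
                             index.ConsumeFront());
    }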
@@ -530,7 +529,13 @@ class ShapeUtil {
static bool HasDegenerateDimensions(const Shape& shape);
// Permutes the dimensions by the given permutation, so
- // return_value.dimensions[permutation[i]] = argument.dimensions[i]
+ // return_value.dimensions[permutation[i]] = argument.dimensions[i].
+ //
+ // Postcondition: For any valid permutation,
+ //
+ // !HasLayout(shape) ||
+ // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
+ // InversePermutation(permutation)).
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
const Shape& shape);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index b6f30af381..ed2d16c0e9 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
+#include <numeric>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -22,12 +23,23 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
using ::testing::ElementsAre;
+TEST(ShapeUtilTest, ShapeIndexViewTest) {
+ ShapeIndex index = {1, 2, 3, 4};
+ ShapeIndexView index_view(index, 1);
+ EXPECT_EQ(3, index_view.size());
+ EXPECT_EQ(ShapeIndexView({2, 3, 4}), index_view);
+ EXPECT_EQ(ShapeIndexView({3, 4}), index_view.ConsumeFront());
+ EXPECT_EQ(ShapeIndexView({2, 3}), index_view.ConsumeBack());
+}
+
TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) {
Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1));
@@ -821,6 +833,28 @@ TEST(ShapeUtilTest, HasDegenerateDimensions) {
ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5})));
}
+TEST(ShapeUtilTest, PermuteDimensionsLayout) {
+ std::vector<int64> layout(3);
+ std::iota(layout.begin(), layout.end(), 0);
+ do {
+ Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout);
+ SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s)));
+
+ std::vector<int64> permutation(3);
+ std::iota(permutation.begin(), permutation.end(), 0);
+ do {
+ SCOPED_TRACE(tensorflow::strings::StrCat(
+ "permutation=", tensorflow::str_util::Join(permutation, ",")));
+
+ // TransposeIsBitcast takes the inverse of the permutation that
+ // PermuteDimensions takes.
+ EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
+ s, ShapeUtil::PermuteDimensions(permutation, s),
+ InversePermutation(permutation)));
+ } while (std::next_permutation(permutation.begin(), permutation.end()));
+ } while (std::next_permutation(layout.begin(), layout.end()));
+}
+
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 02f6fc3a27..6a75aa6794 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -65,6 +65,7 @@ cc_library(
srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -88,6 +89,7 @@ cc_library(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:error_spec",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
@@ -179,6 +181,7 @@ cc_library(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -209,6 +212,7 @@ cc_library(
deps = [
":codegen_test_base",
":filecheck",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:test",
@@ -302,7 +306,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -345,7 +349,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -406,7 +410,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -435,7 +439,7 @@ xla_test(
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -531,6 +535,7 @@ xla_test(
srcs = ["scalar_computations_test.cc"],
shard_count = 32,
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -573,7 +578,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -599,7 +604,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -645,7 +650,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -764,6 +769,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
@@ -780,7 +786,7 @@ xla_test(
CONVOLUTION_TEST_DEPS = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -827,7 +833,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
@@ -874,7 +880,7 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -907,7 +913,7 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -940,7 +946,7 @@ xla_test(
],
deps = [
":test_utils",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1031,6 +1037,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1079,6 +1086,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1149,7 +1157,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1176,7 +1184,7 @@ xla_test(
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1228,6 +1236,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -1246,6 +1255,7 @@ xla_test(
name = "custom_call_test",
srcs = ["custom_call_test.cc"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -1291,6 +1301,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -1368,7 +1379,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1391,7 +1402,7 @@ xla_test(
name = "prng_test",
srcs = ["prng_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1416,6 +1427,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1530,7 +1542,7 @@ xla_test(
name = "cross_replica_sum_test",
srcs = ["cross_replica_sum_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1574,7 +1586,7 @@ xla_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1614,7 +1626,7 @@ xla_test(
name = "compute_constant_test",
srcs = ["compute_constant_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1689,7 +1701,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1714,7 +1726,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1731,6 +1743,7 @@ tf_cc_test(
srcs = ["llvm_compiler_test.cc"],
tags = ["requires-gpu-sm35"],
deps = [
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:cpu_plugin",
@@ -1751,7 +1764,7 @@ xla_test(
name = "round_trip_packed_literal_test",
srcs = ["round_trip_packed_literal_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:packed_literal_reader",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1774,7 +1787,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1802,7 +1815,7 @@ xla_test(
srcs = ["multioutput_fusion_test.cc"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1842,7 +1855,7 @@ xla_test(
name = "local_client_allocation_test",
srcs = ["local_client_allocation_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1865,7 +1878,7 @@ xla_test(
shard_count = 30,
tags = ["optonly"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -1911,7 +1924,7 @@ xla_test(
srcs = ["round_trip_transfer_test.cc"],
deps = [
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1932,7 +1945,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1980,7 +1993,7 @@ xla_test(
":literal_test_util",
":local_client_test_base",
":xla_internal_test_main",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 3bdf98544a..3ae96fa1bc 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -225,7 +225,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = Literal::CreateR1<uint64>({lhs});
+ std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
@@ -239,7 +239,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = Literal::CreateR1<uint64>({rhs});
+ std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
@@ -265,7 +265,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = Literal::CreateR1<int64>({lhs});
+ std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
@@ -278,7 +278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = Literal::CreateR1<int64>({rhs});
+ std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
@@ -303,13 +303,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({a_values});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
- std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({b_values});
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
@@ -1426,7 +1426,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = Literal::CreateR1<float>(values);
+ std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
@@ -1454,10 +1454,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
@@ -1479,10 +1479,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
@@ -1504,10 +1504,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
@@ -1529,10 +1529,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
@@ -1555,15 +1555,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
@@ -1587,15 +1587,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
@@ -1620,15 +1620,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
@@ -1654,19 +1654,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = Literal::CreateR1<float>(values3);
+ std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
client_->TransferToServer(*literal3).ConsumeValueOrDie();
@@ -2101,12 +2101,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
+ LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -2123,12 +2123,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -2145,7 +2145,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -2201,7 +2201,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
// the input tensor is large enough to exercise the vectorized tanh
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>(
+ auto input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16,
-0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32,
-1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85,
@@ -2243,7 +2243,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2277,7 +2277,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2469,9 +2469,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = Literal::MakeTuple(
- {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
- Literal::CreateR2<bool>({{true, false}, {false, false}}).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -2825,8 +2825,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ std::unique_ptr<Literal> a_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
auto a = ConstantLiteral(&builder, *a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@@ -2887,8 +2888,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
// broadcast.
XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
- auto x_literal = Literal::CreateR1<float>({1, 2, 3});
- auto y_literal = Literal::CreateR1<float>({4, 5});
+ auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 217673c8cb..6a024798f9 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
+ input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,12 +242,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -267,12 +267,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -298,11 +298,11 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
- Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -331,11 +331,12 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
- Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -362,12 +363,12 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
- Literal::CreateR1<float>({0, 0}).get(),
- Literal::CreateR1<float>({16, 20}).get()});
+ LiteralUtil::CreateR1<float>({0, 0}).get(),
+ LiteralUtil::CreateR1<float>({16, 20}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -513,11 +514,12 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
+ auto expected_normalized =
+ LiteralUtil::CreateR4FromArray4D<float>(normalized);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -526,9 +528,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto offset_activations =
Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto expected = Literal::MakeTuple({expected_normalized.get(),
- Literal::CreateR1<float>(mean).get(),
- Literal::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
+ LiteralUtil::CreateR1<float>(var).get()});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -613,11 +615,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -800,14 +802,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
});
auto expected_grad_activation =
- Literal::CreateR4FromArray4D<float>(grad_activation);
+ LiteralUtil::CreateR4FromArray4D<float>(grad_activation);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
auto grad_output_literal =
- Literal::CreateR4FromArray4D<float>(grad_output_array);
+ LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
auto input_parameter =
Parameter(&builder, 0, input_literal->shape(), "input");
@@ -833,9 +835,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
grad_output_parameter, epsilon, feature_index);
auto expected =
- Literal::MakeTuple({expected_grad_activation.get(),
- Literal::CreateR1<float>(grad_scale).get(),
- Literal::CreateR1<float>(grad_offset).get()});
+ LiteralUtil::MakeTuple({expected_grad_activation.get(),
+ LiteralUtil::CreateR1<float>(grad_scale).get(),
+ LiteralUtil::CreateR1<float>(grad_offset).get()});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
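[Editor's note] Every hunk in these test files applies the same mechanical rewrite: the static factory methods move from the Literal class to the LiteralUtil class, with call sites and return types otherwise unchanged. A minimal before/after sketch, assuming the 2018-era XLA API in which the factories return std::unique_ptr<Literal>; the function name MakeExpected is illustrative only:

    #include <memory>

    #include "tensorflow/compiler/xla/literal.h"       // class xla::Literal
    #include "tensorflow/compiler/xla/literal_util.h"  // factory class xla::LiteralUtil

    std::unique_ptr<xla::Literal> MakeExpected() {
      // Before this patch: Literal::CreateR1<float>({4, 5})
      // After this patch the same factory lives on LiteralUtil:
      return xla::LiteralUtil::CreateR1<float>({4, 5});
    }
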
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index f40d03bea7..747c82b502 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -95,18 +95,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<bfloat16>(
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
@@ -139,17 +139,17 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<bfloat16>(
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
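[Editor's note] The include churn tracks the same class split: after this change literal.h declares the Literal type while literal_util.h declares the LiteralUtil factories. Not every hunk adds both headers (some files appear to rely on transitive includes), so the sketch below shows the intended layering rather than a strict rule enforced by the patch; ReadScalar and MakeScalar are illustrative names:

    #include <memory>

    #include "tensorflow/compiler/xla/literal.h"       // enough to pass Literal around
    #include "tensorflow/compiler/xla/literal_util.h"  // needed to construct literals

    // Consuming a literal needs only literal.h.
    float ReadScalar(const xla::Literal& l) { return l.Get<float>({}); }

    // Constructing one needs literal_util.h as well.
    std::unique_ptr<xla::Literal> MakeScalar() {
      return xla::LiteralUtil::CreateR0<float>(1.0f);
    }
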
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 91aba9a8de..50dd574624 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array3D<float>* r3_array, float start, float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
client_->TransferToServer(*r3_data).ConsumeValueOrDie();
@@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array2D<float>* r2_array, float start, float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
client_->TransferToServer(*r2_data).ConsumeValueOrDie();
@@ -290,13 +291,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *Literal::CreateR3<float>(
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
auto expected =
- Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
- {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
+ LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
+ {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -365,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
}
}
- auto expected = Literal::CreateR3FromArray3D(expected_array);
+ auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r3_implicit_global_data.get(), r3_global_data.get()},
@@ -390,7 +391,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
Add(r3h, r1h);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
@@ -398,39 +399,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ auto r1 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -438,40 +440,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
- auto r1 =
- ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ auto r1 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -612,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
*v = ApplyOpToFloats(spec.op2, tmp, v3);
});
- auto expected = Literal::CreateR2FromArray2D(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
@@ -626,22 +628,24 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}}));
- auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1}, {2}}));
- auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -650,11 +654,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -663,11 +667,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -676,11 +680,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
- auto expected =
- Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -691,7 +695,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -699,7 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
}
r3 = Mul(r3, ConstantR0<float>(&b, -2));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
@@ -720,7 +724,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
}
r3 = Mul(r3, ConstantR0<float>(&b, -1));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
@@ -733,7 +737,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *Literal::CreateR3<float>(
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 51b9f0d3e3..c7b94b5bba 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -37,7 +37,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
// Test degenerate case of broadcasting a scalar into a scalar.
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {}), input, {}));
@@ -46,14 +46,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
- error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
@@ -63,14 +63,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
// Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple
// to enable testing of the results.
@@ -86,18 +86,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(*result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(*result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
@@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
@@ -116,7 +116,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
// the dimensions, ie transpose.
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
@@ -125,15 +125,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
@@ -143,15 +143,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
*result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 2.0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));
// Broadcast vector in dimension 1.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -166,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -176,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
int64 r1_size = input_data.size();
std::iota(input_data.begin(), input_data.end(), 0.0f);
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(input_data)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));
// Broadcast vector in dimension 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -196,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -207,7 +209,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
std::vector<float> r1_array(64, 42.0);
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(r1_array)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));
// Broadcast vector in dimension 1.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -218,14 +220,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
+ EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
*result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
@@ -238,15 +240,16 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
auto builder = HloComputation::Builder(TestName());
Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D<float>(to_broadcast)));
+ LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));
// Broadcast vector in dimensions 2 and 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -260,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -280,7 +284,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
}
}
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR3FromArray3D<float>(input_vals)));
+ LiteralUtil::CreateR3FromArray3D<float>(input_vals)));
// Broadcast vector in dimensions 2 and 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -291,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
} // namespace
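[Editor's note] The broadcast tests above exercise the rename at the HLO level rather than through the client builder: constants are created as HloInstructions directly from literals. A condensed sketch of the pattern as it appears in these hunks, with the test-fixture plumbing (module creation, ExecuteAndTransfer, error_spec_) omitted:

    // Build a tiny HLO computation: broadcast a scalar constant to a 2x2 array.
    auto builder = xla::HloComputation::Builder("broadcast_example");
    auto input = builder.AddInstruction(xla::HloInstruction::CreateConstant(
        xla::LiteralUtil::CreateR0<float>(42.0)));  // renamed factory
    builder.AddInstruction(xla::HloInstruction::CreateBroadcast(
        xla::ShapeUtil::MakeShape(xla::F32, {2, 2}), input,
        /*broadcast_dimensions=*/{}));
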
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index bc64a19ce2..2086e38b91 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -76,7 +77,8 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant = ConstantLiteral(&builder, *Literal::CreateR0<float>(42.0));
+ auto constant =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -85,8 +87,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *Literal::CreateR1<float>({}));
- auto y = ConstantLiteral(&builder, *Literal::CreateR1<float>({}));
+ auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -95,8 +97,10 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *Literal::CreateR1<float>({1.0f, 2.0f}));
- auto y = ConstantLiteral(&builder, *Literal::CreateR1<float>({2.0f, 3.0f}));
+ auto x =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ auto y =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -129,15 +133,15 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*Literal::CreateR0<float>(1.0f)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
- auto elem = Literal::CreateR0<float>(42.0);
- auto tuple = Literal::MakeTuple({elem.get()});
+ auto elem = LiteralUtil::CreateR0<float>(42.0);
+ auto tuple = LiteralUtil::MakeTuple({elem.get()});
Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
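[Editor's note] The .get() calls around MakeTuple are worth noting: in this revision the Create* factories return owning std::unique_ptr<Literal>, while LiteralUtil::MakeTuple takes non-owning Literal* elements and copies them into the new tuple literal. A minimal sketch of that ownership pattern (MakeScalarTuple is an illustrative name):

    #include <memory>

    #include "tensorflow/compiler/xla/literal.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    std::unique_ptr<xla::Literal> MakeScalarTuple() {
      auto elem = xla::LiteralUtil::CreateR0<float>(42.0);
      // MakeTuple copies from the pointed-to literal; elem keeps ownership
      // and only needs to stay alive for the duration of the call.
      return xla::LiteralUtil::MakeTuple({elem.get()});
    }
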
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 1ad57c075b..0bc8facfe2 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -36,7 +36,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {};
TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
- auto param_literal = Literal::CreateR1<float>({1.1f, 2.2f});
+ auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
@@ -85,12 +85,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_IS_OK(computation_status.status());
auto computation = computation_status.ConsumeValueOrDie();
- auto f32_literal = Literal::CreateR0<float>(1.1f);
+ auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
- auto f32_4_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
- auto u8_4_literal = Literal::CreateR1U8("hola");
+ auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
// Match
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index dafd6ebabb..ef784da457 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -157,7 +157,7 @@ string ClientLibraryTestBase::ExecuteToString(
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -295,7 +295,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -347,7 +347,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -389,7 +389,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1U8(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
@@ -560,8 +560,9 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
- return ConstantLiteral(
- builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
+ return ConstantLiteral(builder, use_bfloat16_
+ ? *LiteralUtil::ConvertF32ToBF16(literal)
+ : literal);
}
std::unique_ptr<GlobalData>
@@ -582,7 +583,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {
- converted_literal = Literal::ConvertF32ToBF16(literal);
+ converted_literal = LiteralUtil::ConvertF32ToBF16(literal);
param_literal = converted_literal.get();
}
std::unique_ptr<GlobalData> data =
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 5361ae6783..fcc9347db5 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -284,7 +285,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*Literal::CreateFromArray(argument), builder);
+ return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -299,13 +300,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
+ return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -410,7 +412,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR0<NativeT>(expected);
+ LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -426,7 +428,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR0<NativeT>(expected);
+ LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -436,7 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<NativeT>(expected);
+ LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -452,7 +454,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<NativeT>(expected);
+ LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -462,7 +464,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2FromArray2D<NativeT>(expected);
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -478,7 +480,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2FromArray2D<NativeT>(expected);
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -488,7 +490,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR3FromArray3D<NativeT>(expected);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -504,7 +506,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR3FromArray3D<NativeT>(expected);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -514,7 +516,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR4FromArray4D<NativeT>(expected);
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -530,7 +532,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR4FromArray4D<NativeT>(expected);
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -539,9 +541,9 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR0(value);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -553,9 +555,9 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR1(values);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -567,9 +569,9 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -581,9 +583,9 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
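[Editor's note] The four CreateR{0..3}Parameter helpers above share one shape: build an F32 literal, optionally down-convert it to bfloat16 when the test runs in bf16 mode, then transfer it to the service. A sketch of that shared pattern with the fixture state (use_bfloat16_, client_) passed in explicitly; UploadR1 is an illustrative name:

    #include <memory>

    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/client/global_data.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    std::unique_ptr<xla::GlobalData> UploadR1(
        xla::Client* client, bool use_bfloat16,
        tensorflow::gtl::ArraySlice<float> values) {
      std::unique_ptr<xla::Literal> literal = xla::LiteralUtil::CreateR1(values);
      // In bf16 mode, F32 test data is converted before upload so the
      // computation sees bfloat16 inputs.
      if (use_bfloat16 && literal->shape().element_type() == xla::F32) {
        literal = xla::LiteralUtil::ConvertF32ToBF16(*literal);
      }
      return client->TransferToServer(*literal).ConsumeValueOrDie();
    }
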
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 831b863998..6ce2f844a3 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -56,7 +56,7 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
client_->Execute(computation, {}, &execution_options));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2WithLayout<int32>(
+ LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
@@ -112,9 +112,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<GlobalData> const_arg,
- client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@@ -136,7 +136,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(auto results,
client_->ExecuteParallel(computation_instances));
- auto expected_result = Literal::CreateR2<int32>({{6, 8}, {10, 12}});
+ auto expected_result = LiteralUtil::CreateR2<int32>({{6, 8}, {10, 12}});
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index eb211dd8ff..ff38246286 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -50,7 +50,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR0<float>(expected_result), *result, error_spec_));
+ *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -67,7 +67,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>(expected_result), *result, error_spec_));
+ *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -89,13 +89,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*Literal::CreateR0<float>(42.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*Literal::CreateR0<float>(123.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*Literal::CreateR0<float>(456.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@@ -143,12 +143,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache
// miss).
- auto rowmaj_array = Literal::CreateR2WithLayout(
+ auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
- auto colmaj_array = Literal::CreateR2WithLayout(
+ auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 1a396b090c..64bf8b3b38 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -207,7 +207,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<int32>({4, 6});
+ LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -221,7 +221,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -242,8 +242,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
&b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
- LayoutUtil::MakeLayout(layout));
+ LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index 1161b560b7..9f288634c0 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -534,8 +534,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) {
// concat
XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
- auto x_literal = Literal::CreateR0<float>(2.f);
- auto y_literal = Literal::CreateR0<float>(3.f);
+ auto x_literal = LiteralUtil::CreateR0<float>(2.f);
+ auto y_literal = LiteralUtil::CreateR0<float>(3.f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
@@ -556,9 +556,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
// produces the correct result in rank 1.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
- auto x_literal = Literal::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
- auto y_literal = Literal::CreateR0<float>(1.5f);
- auto z_literal = Literal::CreateR0<float>(5.5f);
+ auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
+ auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
+ auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
@@ -584,9 +584,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
Array3D<float> x3d(3, 5, 7, 3.14f);
- auto x_literal = Literal::CreateR3FromArray3D<float>(x3d);
- auto y_literal = Literal::CreateR0<float>(1.5f);
- auto z_literal = Literal::CreateR0<float>(5.5f);
+ auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
+ auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
+ auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index ee3c83039b..35f1400fb2 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -344,8 +344,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(),
- Literal::CreateR0<float>(25.0f).get()}),
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
+ LiteralUtil::CreateR0<float>(25.0f).get()}),
{}, error_spec_);
}
@@ -361,8 +361,9 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(),
- Literal::CreateR1<float>({26.0f, 30.0f}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
{}, error_spec_);
}
@@ -399,9 +400,10 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(),
- Literal::CreateR0<float>(12.2f).get(),
- Literal::CreateR1<float>({12.8f, 14.6f}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<bool>(true).get(),
+ LiteralUtil::CreateR0<float>(12.2f).get(),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
{}, error_spec_);
}
@@ -443,12 +445,14 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple(
- {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(),
- Literal::CreateR1<float>({54.4f, 58.4f}).get()})
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(46.6f).get(),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
.get(),
- Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(),
- Literal::CreateR0<float>(9.3f).get()})
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
+ LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
{}, error_spec_);
}
@@ -607,8 +611,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(a).get(),
- Literal::CreateR0<float>(b).get()}),
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
+ LiteralUtil::CreateR0<float>(b).get()}),
{}, error_spec_);
};
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index cc5d3b1176..71d72a9828 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,8 +110,8 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- ConstantLiteral(
- &builder, *Literal::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2)));
+ ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
}
@@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- ConstantLiteral(&builder, *Literal::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -141,7 +141,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
});
input_array.FillWithPZ(pz);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4D(input_array);
+ LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
@@ -159,22 +159,23 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
- Literal::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder,
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR2Near<float>(
- {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>(
- {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
+ LiteralSlice(*result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *Literal::CreateToken());
+ ConstantLiteral(&builder, *LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 292942a49e..dca57fd1c7 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -145,7 +145,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int64>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -164,7 +164,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -182,7 +182,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -199,7 +199,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -216,7 +216,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -253,7 +253,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -391,7 +391,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<half>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -411,7 +411,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<float>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
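
NOTE: the convert_test.cc hunks all touch one three-step pattern: build a
literal on the host, declare a Parameter with that literal's shape, then
transfer the literal to the service to back the parameter at execution time.
A condensed sketch (hypothetical helper; `client` is assumed to be a
connected xla::Client*):

    #include <memory>

    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    void SketchParameterTransfer(xla::Client* client, xla::XlaBuilder* builder) {
      std::unique_ptr<xla::Literal> arg_literal =
          xla::LiteralUtil::CreateR1<xla::int64>({0, 1, 2});
      // The parameter's declared shape comes straight from the host literal.
      xla::XlaOp arg_param =
          xla::Parameter(builder, 0, arg_literal->shape(), "arg_param");
      // Ship the values to the service; the GlobalData handle later feeds
      // Execute()/ComputeAndCompare*.
      std::unique_ptr<xla::GlobalData> arg_data =
          client->TransferToServer(*arg_literal).ConsumeValueOrDie();
      (void)arg_param;
      (void)arg_data;
    }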
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 7605ebf4c0..944366410b 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,7 +93,8 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array))
+ client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 0f6d54d042..a8b8f74ca9 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -434,15 +434,15 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
- auto input_r1 = Literal::CreateR1<float>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
- auto filter_r1 = Literal::CreateR1<float>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<float>(
+ auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
@@ -497,15 +497,15 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<T>(
+ auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
@@ -561,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(param0)),
- std::move(*Literal::CreateFromArray(param1))},
+ {std::move(*LiteralUtil::CreateFromArray(param0)),
+ std::move(*LiteralUtil::CreateFromArray(param1))},
error_spec_);
}
@@ -617,18 +617,18 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
- auto expected_r1 = Literal::CreateR1<T>(expect_elems);
+ auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
auto expected_r3 =
expected_r1->Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
@@ -737,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
@@ -761,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
filter_data.FillIota(10);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))});
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))});
}
} // namespace
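
NOTE: besides the rename, convolution_test.cc shows the header split that
motivates it: the file now includes literal.h, which at this revision is
expected to declare the Literal class itself, while the Create*/MakeTuple
factories live on LiteralUtil — presumably declared in literal_util.h and
reaching this test transitively through the test-base headers. Sketch of
which header supplies what, under that assumption:

    #include <memory>

    #include "tensorflow/compiler/xla/literal.h"       // xla::Literal class
    #include "tensorflow/compiler/xla/literal_util.h"  // xla::LiteralUtil factories

    std::unique_ptr<xla::Literal> MakeExample() {
      return xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
    }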
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index c31d033bb0..8792e7781b 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -1333,17 +1333,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
XlaBuilder builder(TestName());
- auto gradients_flat = Literal::CreateR1<float>({1});
+ auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
- auto weights_flat = Literal::CreateR1<float>({1, 10, 100});
+ auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto weights = ConstantLiteral(&builder, *weights_literal);
- auto expected_flat = Literal::CreateR1<float>({10});
+ auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
@@ -1357,17 +1357,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder builder(TestName());
- auto activations_flat = Literal::CreateR1<float>({1, 2, 3, 4});
+ auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
auto activations = ConstantLiteral(&builder, *activations_literal);
- auto gradients_flat = Literal::CreateR1<float>({100, 10, 1});
+ auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto gradients = ConstantLiteral(&builder, *gradients_literal);
- auto expected_flat = Literal::CreateR1<float>({13, 24, 130});
+ auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index fef42885e5..1dc6ff0f4f 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -58,37 +58,38 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*Literal::CreateR0<bool>(true));
+ TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*Literal::CreateR1<uint32>({}));
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*Literal::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*Literal::CreateR4(
+ TestCopyOp(*LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*Literal::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto builder = HloComputation::Builder(TestName());
// Copy literal to device to use as parameter.
- auto literal = Literal::CreateR0<float>(42.0);
+ auto literal = LiteralUtil::CreateR0<float>(42.0);
Shape shape = literal->shape();
auto param0 = builder.AddInstruction(
@@ -109,7 +110,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto builder = HloComputation::Builder(TestName());
- auto literal = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -131,7 +132,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
Layout* literal_layout =
literal->mutable_shape_do_not_use()->mutable_layout();
@@ -168,7 +169,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(a);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -202,7 +203,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = Literal::CreateR4FromArray4D(a);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
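
NOTE: copy_test.cc also shows why these factories return
std::unique_ptr<Literal>: HloInstruction::CreateConstant takes ownership of
the literal, so every call site moves it in. Minimal sketch (builder
boilerplate as in the hunks above, names illustrative):

    #include <memory>
    #include <utility>

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/service/hlo_computation.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    void SketchConstant(xla::HloComputation::Builder* builder) {
      auto literal = xla::LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
      // CreateConstant consumes the unique_ptr; `literal` is null afterwards.
      xla::HloInstruction* constant = builder->AddInstruction(
          xla::HloInstruction::CreateConstant(std::move(literal)));
      (void)constant;
    }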
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index b151187c4b..d12a4e7fcd 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -45,7 +45,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal = Literal::CreateR1<float>({1, 2, 3});
+ auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
}
@@ -66,10 +66,10 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal0 = Literal::CreateR1<float>({1, 2, 3});
- auto literal1 = Literal::CreateR1<float>({10, 20});
+ auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(
- *Literal::MakeTuple({literal0.get(), literal1.get()}),
+ *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
}
@@ -93,9 +93,9 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal0 = Literal::CreateR1<float>({1, 2, 3});
- auto literal1 = Literal::CreateR1<float>({10, 20});
- EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}),
+ auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get()}));
}
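
NOTE: the cross-replica-sum tests use the text-based flow — parse an HLO
module from a string, execute it against host literals, and compare the
returned literal with operator== — and a single-replica cross-replica-sum is
the identity, which is why each expectation is just the input itself (or a
tuple of the inputs). Condensed sketch, assuming this revision's HloTestBase
helpers are in scope inside the fixture:

    auto module =
        ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
    auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
    // One replica: the "sum" over replicas returns the operand unchanged.
    EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));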
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index d1516a28b0..90f3d1b874 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -74,7 +74,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
@@ -95,7 +95,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
array(1, 1) = 4.0f;
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array)));
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
@@ -111,7 +111,7 @@ XLA_TEST_F(CustomCallTest,
auto b = HloComputation::Builder(TestName());
auto input = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(
Array2D<float>{{1.0f, 2.0f}, {3.0f, 4.0f}})));
auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall(
ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues"));
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index acba67491d..a6a233e71a 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -171,7 +171,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({3.14f, -100.25f});
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cf2e645d47..d86fd7cc2d 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+ *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -194,11 +195,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -217,14 +218,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,9 +287,10 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
+ std::unique_ptr<Literal> dot_lhs_lit =
+ LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
+ param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
@@ -297,7 +299,7 @@ void ParametricDotTest::TestImpl() {
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
std::unique_ptr<Literal> dot_rhs_lit =
- Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
+ LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
@@ -307,7 +309,7 @@ void ParametricDotTest::TestImpl() {
if (param.has_addend) {
addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
- addend_lit = Literal::CreateR2FromArray2DWithLayout(
+ addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
@@ -476,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -510,12 +512,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -583,7 +585,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -591,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -629,13 +631,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -664,15 +666,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
}
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *lhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *lhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *rhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *rhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
XlaBuilder builder(this->TestName());
@@ -733,15 +737,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -782,15 +786,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
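
NOTE: several dot_operation_test.cc hunks use the layout-aware factory
variants, which take an explicit Layout so the same values can be staged
row-major or column-major on the device. Condensed sketch (the tests' own
MinorToMajorForIsRowMajor helper is inlined here as an assumption: {1, 0} is
row-major minor-to-major order, {0, 1} column-major):

    #include <memory>
    #include <vector>

    #include "tensorflow/compiler/xla/array2d.h"
    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    std::unique_ptr<xla::Literal> SketchWithLayout(bool row_major) {
      xla::Array2D<float> values({{1.0f, 2.0f}, {3.0f, -4.0f}});
      xla::Layout layout = xla::LayoutUtil::MakeLayout(
          row_major ? std::vector<xla::int64>{1, 0}
                    : std::vector<xla::int64>{0, 1});
      return xla::LiteralUtil::CreateR2FromArray2DWithLayout(values, layout);
    }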
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index f3c258a4d4..88ac96d6b0 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be an ArraySlice<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -202,18 +202,28 @@ XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64, float>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) {
+ RunR1<uint32, int32>({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4});
+}
XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB<int32, int32>(); }
XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) {
+ RunR2<uint32, int32>({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}});
+}
XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB<int32, float>(); }
XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64, float>(); }
XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64, float>(); }
+XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) {
+ RunR3<uint32, int32>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}},
+ {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}});
+}
XLA_TEST_F(DynamicSliceTest, Int32R1Pred) {
// Slice at dimension start.
@@ -349,15 +359,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*Literal::CreateR0(input_value_int)
+ std::move(*LiteralUtil::CreateR0(input_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
- std::move(*Literal::CreateR0(update_value_int)
+ std::move(*LiteralUtil::CreateR0(update_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
- std::move(*Literal::CreateR0(expected_value_int)
+ std::move(*LiteralUtil::CreateR0(expected_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -380,15 +390,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
tensorflow::gtl::ArraySlice<int> expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR1(update_values_int)
+ std::move(*LiteralUtil::CreateR1(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -411,15 +421,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR2FromArray2D(update_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -442,15 +452,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR3FromArray3D(update_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -520,7 +530,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
std::unique_ptr<Literal> literal =
- Literal::CreateR3FromArray3D<NativeT>(values);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal->ToString();
}
};
@@ -530,21 +540,32 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64, float>(); }
-// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64, float>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) {
+ RunR1<uint32, int32>({0, 1, 2, 3, 4}, {5, 6}, {2147483648u}, {0, 1, 2, 5, 6});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) {
+ RunR2<uint32, int32>({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0},
+ {{0, 1}, {4, 3}});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
+XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) {
+ RunR3<uint32, int32>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}},
+ {2147483648u, 0, 2147483648u},
+ {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}});
+}
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB<int32, float>(); }
@@ -695,7 +716,7 @@ void BM_DynamicSlice(int num_iters) {
XlaBuilder builder("DynamicSlice");
// Create input as a constant: shape [1, 2, 3, 4]
- auto input_literal = Literal::CreateR4(
+ auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
auto input = ConstantLiteral(&builder, *input_literal);
@@ -715,7 +736,7 @@ void BM_DynamicSlice(int num_iters) {
start_indices_shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
- auto start_indices_literal = Literal::CreateR1<int32>({0, 1, 2, 3});
+ auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
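
NOTE: the new UInt32R*OOB cases in this file pin down the out-of-bounds
contract with a start index of 2147483648u (2^31), a value that would go
negative if the index were misread as int32: start indices are clamped so
the slice stays in bounds. Worked through the R1 cases above, assuming
clamp-to-valid-range semantics:

    // DynamicSlice:
    //   RunR1<uint32, int32>({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4});
    //   input size 5, slice size 2, requested start 2147483648
    //   clamped start = min(2147483648, 5 - 2) = 3   ->   slice = {3, 4}
    //
    // DynamicUpdateSlice: start 2147483648 likewise clamps to 3, so updating
    // {0, 1, 2, 3, 4} with {5, 6} yields {0, 1, 2, 5, 6}.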
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index ddc6a7db18..ebba13c5b3 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 74cf8b213e..86bfaea4ef 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -39,7 +39,7 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal =
- Literal::CreateFromDimensions(F32, {input_size});
+ LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
i < known_incorrect_range.second) {
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc
index 93d1c921c4..dcb469087e 100644
--- a/tensorflow/compiler/xla/tests/filecheck.cc
+++ b/tensorflow/compiler/xla/tests/filecheck.cc
@@ -76,6 +76,11 @@ StatusOr<bool> RunFileCheck(const string& input, const string& pattern) {
XLA_LOG_LINES(tensorflow::WARNING, input);
LOG(WARNING) << "FileCheck pattern was:";
XLA_LOG_LINES(tensorflow::WARNING, pattern);
+ } else if (!standard_error.empty()) {
+ LOG(INFO) << "FileCheck stderr:";
+ XLA_LOG_LINES(tensorflow::INFO, standard_error);
+ LOG(INFO) << "FileCheck input was:";
+ XLA_LOG_LINES(tensorflow::INFO, input);
}
return succeeded;
}
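
NOTE: the filecheck.cc hunk is the one behavioral change in this stretch of
the diff: previously FileCheck's stderr was only dumped (at WARNING) when the
match failed; the new else-if branch also surfaces a non-empty stderr at INFO
on success, presumably so that FileCheck warnings remain visible without
failing the run.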
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index f7f9a87413..dc64477935 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -90,7 +90,7 @@ class FusionTest : public HloTestBase {
HloInstruction* hlos[4];
for (int i = 0; i < Arity; ++i) {
hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D(operand_data[i])));
+ LiteralUtil::CreateR2FromArray2D(operand_data[i])));
}
auto answer_shape =
ShapeUtil::MakeShape(prim_type, {test_width, test_height});
@@ -116,7 +116,7 @@ class FusionTest : public HloTestBase {
ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
HloInstruction::FusionKind::kLoop);
- auto expected = Literal::CreateR2FromArray2D(answer_data);
+ auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
@@ -187,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
+ LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
- auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
+ LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+ auto const10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, false, true}, {false, true, false}})));
auto select11 = builder.AddInstruction(
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
HloOpcode::kSelect, const10, add8, const9));
@@ -223,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.5}, {2.72}}),
+ *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -234,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+ LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
// add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
@@ -249,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -270,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
auto hlo_module = CreateNewModule();
auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto x =
builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
auto y = builder.AddInstruction(
@@ -293,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
// add2 = broadcast(const_vector) + const_array
@@ -309,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -317,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto single_element_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), single_element_array));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -332,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -347,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -362,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -377,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -392,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -407,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -422,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -437,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -452,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
hlo_module->AddEntryComputation(builder.Build())
@@ -460,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -468,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -478,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -486,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(S32, {2}), const0, {}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -496,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -514,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1})));
auto dynamic_slice2 =
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {2}), const0, const1, {2}));
@@ -536,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -553,16 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -571,9 +572,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -591,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -603,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -611,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -625,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -633,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+ LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
Window window;
ASSERT_TRUE(
tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
@@ -675,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -687,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -711,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -784,7 +785,7 @@ ENTRY main {
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<float>({{0., 0.}, {1., 0.}});
+ LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -794,7 +795,7 @@ ENTRY main {
test_runner_.Execute(std::move(module), {operand.get()},
/*run_hlo_passes=*/false));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
*result));
}
@@ -838,19 +839,19 @@ void BM_ParallelFusion(int num_iters) {
// Transfer literals to device.
auto param0_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
.ConsumeValueOrDie();
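
[Note] The hunks above, and nearly all of the hunks that follow, apply one mechanical change: the static literal factory functions move off the Literal class onto a separate LiteralUtil class. A minimal before/after sketch, assuming the 2018-era XLA headers shown in this diff and code inside namespace xla:

    // Old spelling, removed throughout this diff:
    //   std::unique_ptr<Literal> v = Literal::CreateR1<int32>({1, 2, 3});
    // New spelling, with the factories living on LiteralUtil:
    std::unique_ptr<Literal> v = LiteralUtil::CreateR1<int32>({1, 2, 3});
    std::unique_ptr<Literal> m =
        LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
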
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b8404826b1..c5ca64fa3f 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -22,9 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-// NB! TODO(b/74360564): These tests do not test out of bounds behavior since
-// that hasn't been specced yet.
-
namespace xla {
namespace {
@@ -63,8 +60,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -84,8 +82,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -105,9 +104,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -127,9 +126,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -149,9 +148,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -171,11 +170,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -195,11 +194,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -219,8 +218,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -240,9 +240,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -261,18 +261,15 @@ ENTRY main {
window_bounds={1, 0}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
// Out of bounds indices must not crash, and the indices in range should
// produce the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for OOB accesses,
- // we should get rid of the mask and check that backends produce the same
- // value for OOB indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -286,29 +283,45 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
+ // Out of bounds indices must not crash, and the indices in range should
+ // produce the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ gather = s32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = s32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
// Negative indices must not crash, and the indices in range should produce
// the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for negative
- // accesses, we should get rid of the mask and check that backends produce the
- // same value for negative indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -322,20 +335,40 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
+ // Negative indices must not crash, and the indices in range should produce
+ // the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = u32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ gather = u32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = u32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -353,9 +386,9 @@ ENTRY main {
window_bounds={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR3<int32>(
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -374,8 +407,8 @@ ENTRY main {
window_bounds={1}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -395,8 +428,8 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -419,8 +452,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -443,9 +477,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -468,9 +502,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -493,11 +527,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -521,11 +555,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -548,8 +582,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -572,9 +607,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -609,12 +644,13 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
Gather(operand, indices, dim_numbers, {1, 3});
std::vector<int32> expected = {};
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
- client_->TransferToServer(*Literal::CreateR2<int32>(
- {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> operand_arg,
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
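
[Note] The gather test changes above are behavioral, not just part of the rename: the in_bounds_mask parameters and the TODO(b/74360564) comments are deleted, so these tests now assume gather has defined out-of-bounds behavior and let RunTest compare each backend's raw output against the reference directly. For contrast, here is a self-contained sketch of the masking idiom being removed; this is a plain C++ stand-in for illustration, not the XLA API:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Zero out entries whose source index was out of range, so that values
    // unspecified by the old semantics never reached the comparison.
    std::vector<int32_t> MaskOutOfBounds(const std::vector<int32_t>& gathered,
                                         const std::vector<int32_t>& in_bounds) {
      std::vector<int32_t> masked(gathered.size());
      for (std::size_t i = 0; i < gathered.size(); ++i) {
        masked[i] = gathered[i] * in_bounds[i];  // in_bounds[i] is 0 or 1
      }
      return masked;
    }
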
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index fd85118849..73a47eda72 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 242cc5db11..b662e83716 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -276,9 +276,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
HloComputation* HloTestBase::FindComputation(HloModule* module,
tensorflow::StringPiece name) {
- auto it = c_find_if(module->computations(),
+ auto computations = module->computations();
+ auto it = c_find_if(computations,
[&](HloComputation* c) { return c->name() == name; });
- if (it == module->computations().end()) {
+ if (it == computations.end()) {
return nullptr;
}
return *it;
@@ -287,9 +288,10 @@ HloComputation* HloTestBase::FindComputation(HloModule* module,
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
tensorflow::StringPiece name) {
for (const HloComputation* c : module->computations()) {
- auto it = c_find_if(c->instructions(),
+ auto instructions = c->instructions();
+ auto it = c_find_if(instructions,
[&](HloInstruction* i) { return i->name() == name; });
- if (it != c->instructions().end()) {
+ if (it != instructions.end()) {
return *it;
}
}
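
[Note] The FindComputation/FindInstruction change above is a lifetime fix rather than part of the rename: module->computations() returns a range object by value, so calling it twice hands c_find_if an iterator from one temporary and compares it against end() of a different temporary, which is undefined behavior. Binding the range to a local first makes both iterators come from the same object. The same bug in miniature, using only the standard library:

    #include <algorithm>
    #include <string>
    #include <vector>

    std::vector<std::string> MakeNames() { return {"a", "b", "c"}; }

    bool Contains(const std::string& name) {
      // Buggy shape of the old code: two calls, two distinct temporaries,
      // so the found iterator and end() belong to different containers (UB):
      //   auto it = std::find(MakeNames().begin(), MakeNames().end(), name);
      //   return it != MakeNames().end();
      auto names = MakeNames();  // fixed: one object, one begin/end pair
      return std::find(names.begin(), names.end(), name) != names.end();
    }
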
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 9009d67cea..66719b1460 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -200,6 +200,13 @@ class HloTestBase : public ::testing::Test {
->ResetLayout(layout);
}
+ void ForceResultLayout(HloModule* module, const Layout& layout,
+ ShapeIndexView shape_index) {
+ module->mutable_entry_computation_layout()
+ ->mutable_result_layout()
+ ->ResetLayout(layout, shape_index);
+ }
+
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {
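
[Note] The new ForceResultLayout overload above threads a ShapeIndexView through to ResetLayout, so a test can pin the layout of a single leaf of a tuple-shaped result instead of the whole result. A hypothetical call site; the {1} index addressing the tuple's second element is an illustration, not taken from this diff:

    // Hypothetical usage in an HloTestBase-derived test; assumes the entry
    // computation of `module` returns a tuple with at least two leaves.
    ShapeIndex index = {1};  // second element of the result tuple
    ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}), index);
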
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index d1b8a6cf0b..31a099c15f 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/error_spec.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -154,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -175,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -222,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -231,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
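
[Note] With the rename applied, each Expect* wrapper above builds its expected literal through a LiteralUtil factory and delegates to Equal or Near. Typical call sites look like the following sketch, where `actual` stands for a LiteralSlice the test already holds:

    // Exact comparison for integral data:
    LiteralTestUtil::ExpectR1Equal<int32>({1, 2, 3}, actual);
    // Tolerance-based comparison for floating-point data:
    LiteralTestUtil::ExpectR2Near<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}, actual,
                                         ErrorSpec(1e-4));
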
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index bbac7285ae..f297b2b847 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,8 +31,9 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple({
- Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
});
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
}
@@ -42,11 +43,13 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = Literal::MakeTuple({
- Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
+ std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
});
- std::unique_ptr<Literal> rhs = Literal::MakeTuple({
- Literal::CreateR0<int32>(64).get(), Literal::CreateR0<int32>(42).get(),
+ std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(64).get(),
+ LiteralUtil::CreateR0<int32>(42).get(),
});
CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
};
@@ -55,8 +58,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto dummy_lambda = [] {
- auto two = Literal::CreateR0<float>(2);
- auto four = Literal::CreateR0<float>(4);
+ auto two = LiteralUtil::CreateR0<float>(2);
+ auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
};
@@ -98,8 +101,8 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
}
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
- auto expected = Literal::CreateR1<int32>({1, 2, 3});
- auto actual = Literal::CreateR1<int32>({4, 5, 6});
+ auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result =
LiteralTestUtil::Equal(*expected, *actual);
EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
@@ -107,25 +110,26 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
}
TEST(LiteralTestUtilTest, NearComparatorR1) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- auto b =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- auto b =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ auto b = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- auto b = Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b =
+ LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
}
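
[Note] Tuple construction moves to LiteralUtil as well. As the hunks above show, MakeTuple borrows its elements (hence the .get() calls, with the unique_ptrs kept alive by the caller), while MakeTupleOwned consumes them. A short sketch of both, following the patterns used in this file and in local_client_execute_test.cc below:

    // Borrowing form: a and b must outlive the call.
    auto a = LiteralUtil::CreateR0<int32>(42);
    auto b = LiteralUtil::CreateR0<int32>(64);
    auto borrowed = LiteralUtil::MakeTuple({a.get(), b.get()});

    // Owning form: the vector's unique_ptrs are moved in.
    std::vector<std::unique_ptr<Literal>> elems;
    elems.push_back(LiteralUtil::CreateR0<float>(7.0f));
    auto owned = LiteralUtil::MakeTupleOwned(std::move(elems));
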
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 082bc34136..13df83ffff 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
@@ -64,7 +65,7 @@ class LLVMCompilerTest : public ::testing::Test {
// Create HLO module, and run the compiler.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
@@ -86,7 +87,7 @@ class LLVMCompilerTest : public ::testing::Test {
void TestMultiModuleCompilation(LLVMCompiler *compiler) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
std::unique_ptr<HloModule> hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index 2c45f19c09..6fc1115097 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <utility>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,28 +26,28 @@ limitations under the License.
namespace xla {
-void LLVMIRGenTestBase::SetIrHook(bool match_optimized_ir) {
+void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
auto llvm_compiler = GetLLVMCompiler();
using std::placeholders::_1;
// Add the IR inspection hook to the LLVM compiler.
if (match_optimized_ir) {
llvm_compiler->SetPostOptimizationHook(
- std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ std::bind(&LlvmIrGenTestBase::IrHook, this, _1));
} else {
llvm_compiler->SetPreOptimizationHook(
- std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ std::bind(&LlvmIrGenTestBase::IrHook, this, _1));
}
}
-void LLVMIRGenTestBase::ResetIrHook() {
+void LlvmIrGenTestBase::ResetIrHook() {
auto llvm_compiler = GetLLVMCompiler();
llvm_compiler->RemovePreOptimizationHook();
llvm_compiler->RemovePostOptimizationHook();
}
-void LLVMIRGenTestBase::CompileAndVerifyIr(
+void LlvmIrGenTestBase::CompileAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const string& pattern,
bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
@@ -58,7 +59,17 @@ void LLVMIRGenTestBase::CompileAndVerifyIr(
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
-void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
+void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
+ const string& expected_llvm_ir,
+ bool match_optimized_ir) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir);
+}
+
+void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
const string& pattern, bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
@@ -71,11 +82,11 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
-LLVMCompiler* LLVMIRGenTestBase::GetLLVMCompiler() {
+LLVMCompiler* LlvmIrGenTestBase::GetLLVMCompiler() {
return static_cast<LLVMCompiler*>(backend().compiler());
}
-Status LLVMIRGenTestBase::IrHook(const llvm::Module& module) {
+Status LlvmIrGenTestBase::IrHook(const llvm::Module& module) {
ir_ = llvm_ir::DumpModuleToString(module);
return Status::OK();
}
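
[Note] Two things change in this file: the class is renamed from LLVMIRGenTestBase to LlvmIrGenTestBase (matching the usual CamelCase treatment of acronyms), and CompileAndVerifyIr gains a text-based overload that builds an HloModuleConfig from the test's debug options, parses the HLO with ParseHloString, and forwards to the module-based overload. A sketch of the boilerplate the overload absorbs; the HLO snippet and FileCheck pattern here are illustrative, not from this diff:

    // Before, each test parsed its own module:
    //   HloModuleConfig config;
    //   config.set_debug_options(GetDebugOptionsForTest());
    //   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_text, config));
    //   CompileAndVerifyIr(std::move(module), pattern, /*match_optimized_ir=*/false);
    // After, a single call suffices:
    const string hlo_text = R"(
    HloModule add
    ENTRY add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT sum = f32[] add(x, y)
    }
    )";
    CompileAndVerifyIr(hlo_text, "CHECK: fadd", /*match_optimized_ir=*/false);
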
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
index 74cbb5f5df..018f9546af 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// Tests that verify IR emitted by the CPU/GPU backend is as expected.
-class LLVMIRGenTestBase : public CodegenTestBase {
+class LlvmIrGenTestBase : public CodegenTestBase {
protected:
// Compiles the given HLO module to LLVM IR and verifies the IR matches the
// given pattern. `pattern` is in the FileCheck pattern matching syntax
@@ -38,6 +38,12 @@ class LLVMIRGenTestBase : public CodegenTestBase {
void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
const string& pattern, bool match_optimized_ir);
+ // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create
+ // an HLO module.
+ void CompileAndVerifyIr(const string& hlo_text,
+ const string& expected_llvm_ir,
+ bool match_optimized_ir = false);
+
// Compiles the given HLO module to LLVM IR and verifies the IR matches the
// given pattern. `pattern` is in the FileCheck pattern matching syntax
// (http://llvm.org/docs/CommandGuide/FileCheck.html).
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 9191be9fd9..0df50150ae 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 2c6393794e..2f4d197ae6 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -68,7 +68,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
auto y = ConstantR0<float>(&builder, 123.0f);
Add(x, y);
- auto x_value = LiteralToShapedBuffer(*Literal::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
@@ -81,7 +81,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
auto y = ConstantR1<float>(&builder, {});
Add(x, y);
- auto x_array = LiteralToShapedBuffer(*Literal::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
@@ -95,7 +95,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
@@ -109,7 +109,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
@@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -298,12 +298,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = Literal::MakeTuple(
- {Literal::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- Literal::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
+ auto y_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
auto x_buffer = LiteralToShapedBuffer(*x_literal);
auto y_buffer = LiteralToShapedBuffer(*y_literal);
@@ -344,12 +344,12 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = Literal::MakeTuple(
- {Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR1<float>({42.0, 75.0, 123.0}).get()})
+ auto arg_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
.get(),
- Literal::CreateR1<float>({222.0, -2.0, 10.0}).get()});
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -377,9 +377,9 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
+ auto arg_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -429,10 +429,10 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
// -tuple_index}.
std::vector<std::unique_ptr<Literal>> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
- arg_elements.push_back(Literal::CreateR1<float>({1.0f * i, -1.0f * i}));
+ arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
std::unique_ptr<Literal> arg_literal =
- Literal::MakeTupleOwned(std::move(arg_elements));
+ LiteralUtil::MakeTupleOwned(std::move(arg_elements));
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -480,12 +480,13 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
for (int i = 0; i < kFanout; ++i) {
std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
- inner_tuple_elements.push_back(Literal::CreateR0<float>(i + j));
+ inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
outer_tuple_elements.push_back(
- Literal::MakeTupleOwned(std::move(inner_tuple_elements)));
+ LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements)));
}
- auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements));
+ auto arg_literal =
+ LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -524,11 +525,11 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = Literal::CreateR0<float>(123.0);
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
std::vector<std::unique_ptr<Literal>> arg_vector;
arg_vector.push_back(std::move(arg_literal));
- arg_literal = Literal::MakeTupleOwned(std::move(arg_vector));
+ arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
@@ -551,7 +552,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -567,7 +568,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -584,7 +585,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -767,14 +768,10 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
- ASSERT_IS_OK(local_client_->mutable_backend()
- ->BorrowStream(0)
- .ValueOrDie()
- ->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
{2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
@@ -795,29 +792,29 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
};
// Array shapes.
- test_to_device_and_back(*Literal::CreateR0<float>(42.0));
- test_to_device_and_back(*Literal::CreateR0<bool>(true));
- test_to_device_and_back(*Literal::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*Literal::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*Literal::MakeTuple({}));
+ test_to_device_and_back(*LiteralUtil::MakeTuple({}));
// Non-nested tuples.
test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR0<float>(12223.0).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
- Literal::CreateR0<float>(123456.0).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<float>(123456.0).get()}));
// Nested tuple.
- test_to_device_and_back(*Literal::MakeTuple(
- {Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
- Literal::CreateR0<float>(123456.0).get()})
+ test_to_device_and_back(*LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<float>(123456.0).get()})
.get(),
- Literal::CreateR0<bool>(false).get()}));
+ LiteralUtil::CreateR0<bool>(false).get()}));
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -835,13 +832,13 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
};
test_to_device_and_back(
- *Literal::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*Literal::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *Literal::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR1<double>({1.0, -42.0}).get(),
- Literal::CreateR0<int64>(123456789000LL).get()}));
+ *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -860,7 +857,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
}));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *Literal::CreateR1<float>({-5.0, 123.0, 42.0}),
+ *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
// Join the thread.
@@ -869,9 +866,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
}
-// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel.
-// 2017-10-18.
-XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
+XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = Infeed(&builder, shape);
@@ -885,7 +880,7 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *Literal::CreateR1<float>({-5.0, 123.0, 42.0}),
+ *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
@@ -920,7 +915,7 @@ void BM_LocalClientOverhead(int num_iters) {
transfer_manager
->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
- auto literal = Literal::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
+ auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
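Note: the hunks above are all instances of one mechanical change — the literal factory functions (CreateR0 through CreateR4, the FromArray and WithLayout variants, MakeTuple, MakeTupleOwned) move from static members of the Literal class to LiteralUtil, while includes of literal_util.h become includes of literal.h where only the Literal class itself is needed. A minimal before/after sketch of a call site; the assumption here is that LiteralUtil's declarations stay in literal_util.h and reach these tests transitively, which the patch does not show:

    #include "tensorflow/compiler/xla/literal.h"  // class xla::Literal only

    std::unique_ptr<xla::Literal> MakeExample() {
      // Before this change: xla::Literal::CreateR1<float>({1.0f, 2.0f})
      return xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
    }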
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index c31ba0e713..88797a7d0a 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -189,19 +189,7 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<LocalExecutable> executable,
local_client_->Compile(computation, argument_layouts, build_options));
- TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
-
- auto device_ordinal =
- build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
- auto* stream = run_options.stream();
- if (!stream) {
- stream = local_client_->mutable_backend()
- ->BorrowStream(device_ordinal)
- .ValueOrDie()
- .get();
- }
- TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
- return std::move(ret);
+ return executable->Run(arguments, run_options);
}
} // namespace xla
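Note: besides the rename, this hunk deletes the explicit host synchronization that ExecuteLocally performed after executable->Run(), and the CompileExecutable test above loses its matching BlockHostUntilDone() call. The apparent implication is that Run() now returns buffers that are safe to consume without a manual block; that is inferred from the deletions, not stated in the patch. Condensed from the removed lines, a caller that still wanted an explicit sync point could do the equivalent of:

    // Reassembled from the deleted code; 0 stands in for the default
    // device ordinal derived from build_options.
    auto* stream = run_options.stream();
    if (!stream) {
      stream = local_client_->mutable_backend()
                   ->BorrowStream(/*device_ordinal=*/0)
                   .ValueOrDie()
                   .get();
    }
    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());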
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 1b3bc9d504..7ddc636931 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -169,7 +169,7 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies (lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(42.0);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -183,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -198,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -212,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -225,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -239,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
+ LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -255,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -272,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -287,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -343,11 +343,11 @@ TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -365,12 +365,12 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -392,12 +392,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -414,15 +414,15 @@ TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param2_literal =
- Literal::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
+ LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
@@ -476,11 +476,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
auto error_add = sub_builder->BuildAndNoteError();
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
@@ -513,8 +513,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
@@ -540,8 +540,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
@@ -565,7 +565,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(10.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 17b1807f44..069b8a881f 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -63,8 +63,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
Exp(data);
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
- {0.36788f, 1.64872f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
+ {0.36788f, 1.64872f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
}
@@ -92,8 +92,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
Map(&builder, {data}, add_half, {0, 1});
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
- {-0.5f, 1.0f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
+ {-0.5f, 1.0f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
}
@@ -111,8 +111,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
Max(lhs, rhs);
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
- {3.0f, -4.0f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
+ {3.0f, -4.0f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
}
@@ -200,12 +200,14 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 6597748c8d..eb06b115da 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase {
const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(8.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape0, "0"));
@@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
- auto actual = ExecuteAndTransfer(
- std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
+ auto actual =
+ ExecuteAndTransfer(std::move(hlo_module),
+ {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShape(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect =
+ std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)),
- Literal::CreateR0<float>(1.0)),
- Literal::MakeTupleOwned(Literal::CreateR0<float>(3.0),
- Literal::CreateR0<int32>(4)));
+ auto param = LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
+ LiteralUtil::CreateR0<float>(1.0)),
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
+ LiteralUtil::CreateR0<int32>(4)));
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)), *result));
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -232,7 +234,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
@@ -265,7 +267,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
@@ -308,12 +310,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
*result));
}
@@ -338,12 +342,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{6, 8}, {10, 12}}),
- Literal::CreateR2<float>({{25, 36}, {49, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
*result));
}
@@ -369,13 +375,14 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}),
- Literal::CreateR1<float>({36, 64}),
- Literal::CreateR1<float>({66, 138})),
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
*result));
}
@@ -401,14 +408,15 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
- Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
*result));
}
@@ -434,14 +442,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{6, 8}, {10, 12}}),
- Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
- Literal::CreateR2<float>({{25, 36}, {49, 64}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
*result));
}
@@ -468,14 +478,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR1<float>({14, 22}),
- Literal::CreateR3<float>({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
- Literal::CreateR3<float>(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
*result));
}
@@ -502,15 +514,16 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- auto init1 = Literal::CreateR0<float>(5);
- auto init2 = Literal::CreateR0<float>(6);
+ auto param =
+ LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto init1 = LiteralUtil::CreateR0<float>(5);
+ auto init2 = LiteralUtil::CreateR0<float>(6);
std::unique_ptr<Literal> result = ExecuteNoHloPasses(
std::move(module), {param.get(), init1.get(), init2.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{167, 172}, {176, 180}}),
- Literal::CreateR2<float>({{6, 6}, {6, 8}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
+ LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
*result));
}
@@ -537,19 +550,20 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<Eigen::half>(
+ auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
std::unique_ptr<Literal> result =
ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}}),
- Literal::CreateR3<Eigen::half>({{{Eigen::half(1), Eigen::half(2)},
- {Eigen::half(3), Eigen::half(4)}},
- {{Eigen::half(5), Eigen::half(6)},
- {Eigen::half(7), Eigen::half(8)}}})),
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
+ LiteralUtil::CreateR3<Eigen::half>(
+ {{{Eigen::half(1), Eigen::half(2)},
+ {Eigen::half(3), Eigen::half(4)}},
+ {{Eigen::half(5), Eigen::half(6)},
+ {Eigen::half(7), Eigen::half(8)}}})),
*result));
}
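Note: the tuple hunks above exercise both tuple factories, and their ownership rules explain the .get() noise in this diff: MakeTuple borrows raw Literal* and leaves ownership with the caller, while MakeTupleOwned consumes its elements, either variadically or as a vector of unique_ptrs. A side-by-side sketch with arbitrary element values:

    // Borrowing form: the unique_ptrs must stay alive across the call.
    auto a = LiteralUtil::CreateR0<float>(3.0);
    auto b = LiteralUtil::CreateR0<int32>(4);
    auto borrowed = LiteralUtil::MakeTuple({a.get(), b.get()});

    // Owning form: elements are moved in, no .get() bookkeeping.
    std::vector<std::unique_ptr<Literal>> elems;
    elems.push_back(LiteralUtil::CreateR0<float>(3.0));
    elems.push_back(LiteralUtil::CreateR0<int32>(4));
    auto owned = LiteralUtil::MakeTupleOwned(std::move(elems));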
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index 2e5081bbcb..e428fa9b5e 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- Pad(AddParam(*Literal::CreateR1<float>({}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- Pad(AddParam(*Literal::CreateR1<float>({}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- Pad(AddParam(*Literal::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
@@ -132,7 +132,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*Literal::CreateR0<float>(1.5), &b), r4_padding_on_dim0_dim1_);
+ AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
}
@@ -147,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0<float>(1.5), &b),
+ Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
@@ -166,7 +167,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
- Pad(AddParam(input, &b), AddParam(*Literal::CreateR0<float>(pad_value), &b),
+ Pad(AddParam(input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
@@ -205,11 +207,11 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
input = input->Relayout(layout);
- Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0<float>(pad_value), &b),
- padding_config);
+ Pad(AddParam(*input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -251,11 +253,11 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 0, 0, 0) = 1.0f;
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
input = input->Relayout(layout);
- Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0<float>(pad_value), &b),
- padding_config);
+ Pad(AddParam(*input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -329,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- Pad(input, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -351,7 +353,8 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- Pad(input, AddParam(*Literal::CreateR0<float>(3.14f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -376,7 +379,8 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -403,7 +407,8 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -430,7 +435,8 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -446,12 +452,13 @@ XLA_TEST_P(PadTestFloat, ReducePad) {
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- Reduce(input, AddParam(*Literal::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- Pad(reduce, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
+ padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
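Note: the PaddingConfig knobs used throughout pad_test.cc compose per dimension: edge_padding_low cells before the data, interior_padding cells between each adjacent pair of input cells, and edge_padding_high after. For a non-empty dimension of extent n the padded extent is low + n + (n - 1) * interior + high; Pad1DS3Array above is exactly 3 + 3 + 2 * 1 + 0 = 8 elements. A sketch mirroring that test's configuration:

    PaddingConfig padding_config = MakeNoPaddingConfig(/*rank=*/1);
    auto* dimension = padding_config.mutable_dimensions(0);
    dimension->set_edge_padding_low(3);   // three pad cells up front
    dimension->set_edge_padding_high(0);  // none at the end
    dimension->set_interior_padding(1);   // one pad cell between neighbors
    // {1, 2, 3} padded with 0.1 -> {0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}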
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 2620063aa4..8ba1d11b33 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -42,7 +42,8 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -54,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -67,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({3.14f, -100.25f});
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -80,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1U8(str);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -94,7 +95,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -106,7 +107,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -122,12 +123,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
@@ -153,7 +154,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -167,12 +168,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
Parameter(&builder, 1, literal1->shape(), "param1");
@@ -187,11 +188,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20, 30});
+ std::unique_ptr<Literal> literal1 =
+ LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
@@ -231,7 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = Literal::CreateR1<float>(sum_value);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
client_->TransferToServer(*literal).ConsumeValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -266,7 +268,7 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -298,7 +300,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -322,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({target + i, target + i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -354,7 +356,7 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -364,7 +366,7 @@ XLA_TEST_F(ParamsTest,
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = Literal::CreateR0<bool>(false);
+ std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
XlaOp bool_param =
@@ -421,10 +423,10 @@ XLA_TEST_F(ParamsTest,
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({i, i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -441,9 +443,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*Literal::MakeTuple({
- Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
}))
.ConsumeValueOrDie();
@@ -455,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
Parameter(&builder, 0, literal->shape(), "input");
@@ -467,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
Parameter(&builder, 0, literal->shape(), "input");
@@ -478,7 +480,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});
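Note: the R2_2x2_Layout tests above pin down the layout convention: LayoutUtil::MakeLayout lists dimensions minor-to-major, i.e. fastest-varying first, so {1, 0} stores a 2-D literal row-major and {0, 1} column-major. The nested-initializer values are always logical, row by row; the layout only changes physical placement. A minimal sketch:

    // One logical matrix, two physical layouts.
    auto row_major = LiteralUtil::CreateR2WithLayout<float>(
        {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));  // dim 1 minor
    auto col_major = LiteralUtil::CreateR2WithLayout<float>(
        {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));  // dim 0 minor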
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 8e163e885d..5ebf8344d2 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -193,7 +193,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
client_->TransferToServer(*param0_literal));
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index 9052b188ed..a080dd1732 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -95,21 +95,21 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input =
- Literal::CreateR4<float>({{ /*i0=0*/
- {/*i1=0*/
- {-0.246092796, -0.179497838, -0.161181688},
- {-0.151643038, -0.240213156, -0.198156}},
- {/*i1=1*/
- {-0.14222312, -0.162200093, -0.193907976},
- {-0.239411, -0.198166847, -0.172471642}}},
- { /*i0=1*/
- {/*i1=0*/
- {-0.22965157, -0.218723893, -0.129257083},
- {-0.188762426, -0.16123569, -0.181166649}},
- {/*i1=1*/
- {-0.241772294, -0.245131493, -0.160247207},
- {-0.179881215, -0.23383224, -0.121976733}}}});
+ std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ {{ /*i0=0*/
+ {/*i1=0*/
+ {-0.246092796, -0.179497838, -0.161181688},
+ {-0.151643038, -0.240213156, -0.198156}},
+ {/*i1=1*/
+ {-0.14222312, -0.162200093, -0.193907976},
+ {-0.239411, -0.198166847, -0.172471642}}},
+ { /*i0=1*/
+ {/*i1=0*/
+ {-0.22965157, -0.218723893, -0.129257083},
+ {-0.188762426, -0.16123569, -0.181166649}},
+ {/*i1=1*/
+ {-0.241772294, -0.245131493, -0.160247207},
+ {-0.179881215, -0.23383224, -0.121976733}}}});
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 4c1aa12106..04c7f31646 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -230,7 +230,8 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({input_values});
+ std::unique_ptr<Literal> a_literal =
+ LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
@@ -253,7 +254,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
@@ -282,7 +283,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
@@ -308,7 +309,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
@@ -332,7 +333,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
@@ -357,7 +358,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
auto a = Parameter(&builder, 0, a_literal->shape(), "a");
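Every hunk in this file follows the same mechanical substitution: the include moves from literal_util.h to literal.h, and each static Literal::CreateRn factory call becomes LiteralUtil::CreateRn. A before/after sketch of the pattern, not tied to any one test:

// Before this commit:
//   std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
// After it, the factory lives on LiteralUtil:
std::unique_ptr<xla::Literal> a_literal =
    xla::LiteralUtil::CreateR1<float>({1.00001});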
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index c9f57cbb16..1407fca72f 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase {
ReduceTest() {
// Implementation note: laid out z >> y >> x by default.
// clang-format off
- literal_2d_ = Literal::CreateR2<float>({
+ literal_2d_ = LiteralUtil::CreateR2<float>({
// x0 x1 x2
{ 1.f, 2.f, 3.f}, // y0
{ 4.f, 5.f, 6.f}, // y1
});
- literal_3d_ = Literal::CreateR3Projected<float>({
+ literal_3d_ = LiteralUtil::CreateR3Projected<float>({
// x0 x1 x2
{ 1.f, 2.f, 3.f}, // y0
{ 4.f, 5.f, 6.f}, // y1
@@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase {
}
}
std::unique_ptr<Literal> input_literal =
- Literal::CreateR1(AsSlice(input_data));
+ LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -133,7 +133,7 @@ class ReduceTest : public ClientLibraryTestBase {
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = Literal::CreateR1(input_data);
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -175,7 +175,7 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -209,7 +209,7 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -237,7 +237,7 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -295,7 +295,7 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -450,7 +450,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -482,7 +482,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -531,7 +531,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR3FromArray3D(input_data);
+ LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -594,7 +594,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
auto max = CreateScalarMaxComputation(F32, &builder);
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
Reduce(ConstantLiteral(&builder, *input_literal),
ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
@@ -609,7 +609,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
auto min = CreateScalarMinComputation(F32, &builder);
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
Reduce(ConstantLiteral(&builder, *input_literal),
ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
@@ -623,7 +623,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto min = CreateScalarMinComputation(U32, &builder);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
@@ -635,7 +635,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto max = CreateScalarMaxComputation(U32, &builder);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
@@ -818,7 +818,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
// input_array.FillRandom(3.14f, 0.05);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR3FromArray3D(input_array);
+ auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
@@ -872,7 +872,8 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
auto a = ConstantR0<float>(&builder, 2.0f);
auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({1.0f, 4.0f});
+ std::unique_ptr<Literal> b_literal =
+ LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
auto b = Parameter(&builder, 0, b_literal->shape(), "b");
@@ -900,7 +901,7 @@ class ReduceInitializerTest : public ReduceTest {
auto init = ConstantR0<T>(&builder, initializer);
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
- auto input_literal = Literal::CreateR1<T>(input_arr);
+ auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
@@ -950,10 +951,11 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(operand);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = Literal::CreateR0<float>(init);
+ std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
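The reduce tests keep their create/transfer/compare shape throughout; only the factory class changes. A condensed sketch of that flow under the renamed API, assuming the ClientLibraryTestBase members (client_) used above and illustrative dimensions:

// Sketch: build a literal, pick a layout, and transfer it to the service.
Array2D<float> input_data(4, 3);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
    LiteralUtil::CreateR2FromArray2D(input_data);  // was Literal::...
input_literal =
    input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
    client_->TransferToServer(*input_literal).ConsumeValueOrDie();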
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 741974480c..c2681f70f7 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_);
+ auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ &builder_);
ReduceWindow(input, init,
CreateScalarAddComputation(FloatType(), &builder_),
window_dimensions, window_strides, padding);
@@ -81,7 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_);
+ auto init =
+ CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
ReduceWindow(input, init,
CreateScalarMaxComputation(FloatType(), &builder_),
window_dimensions, window_strides, padding);
@@ -91,7 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_);
+ auto init =
+ CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
ReduceWindow(input, init,
CreateScalarMinComputation(FloatType(), &builder_),
window_dimensions, window_strides, padding);
@@ -102,9 +104,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
ReduceWindow(input, init_value,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -119,32 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(1.0), &builder_);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
/*window_dimensions=*/{},
/*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR0<float>(43.0), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({100, 1}), {},
- ErrorSpec(0.00001));
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ {}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *Literal::CreateR1<float>({1000, 100, 10, 1, 1}), {},
- ErrorSpec(0.00001));
+ *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ {}, ErrorSpec(0.00001));
}
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
@@ -156,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -171,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -185,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -202,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -224,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -247,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -272,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -289,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -308,12 +310,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
Min(Add(lhs, rhs),
- CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get()));
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
ReduceWindow(
input,
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -327,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -347,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -376,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
auto shape = ShapeUtil::MakeShape(F32, input_dims);
std::unique_ptr<Literal> arg_literal =
- Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@@ -385,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
std::unique_ptr<Literal> expected =
- Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
}
@@ -394,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -408,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -416,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -430,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -438,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -452,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -473,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *Literal::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -499,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -516,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -535,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_,
+ *LiteralUtil::CreateFromArray<float>(*res), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
Array2D<float> input_array(6, 4, 1.0f);
XlaOp input = Broadcast(
- CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4});
+ CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
@@ -550,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_,
+ *LiteralUtil::CreateFromArray<float>(*res), {},
+ DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -609,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
param.base_bounds[2], param.base_bounds[3]);
input.FillIota(1);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
@@ -621,7 +625,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
auto computation = param.reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
@@ -647,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*stride=*/param.strides,
/*padding=*/padding);
std::unique_ptr<Literal> expected_literal =
- Literal::CreateFromArray(*expected);
+ LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
input_literal->shape().element_type(),
AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
@@ -959,14 +963,14 @@ TEST_P(R3ReduceWindowTest, Add) {
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], 1.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR3FromArray3DWithLayout(
+ LiteralUtil::CreateR3FromArray3DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindow(/*operand=*/parameter,
/*init_value=*/init_value,
/*computation=*/CreateScalarAddComputation(FloatType(), &b),
@@ -977,7 +981,7 @@ TEST_P(R3ReduceWindowTest, Add) {
/*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
@@ -1093,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2DWithLayout(
+ LiteralUtil::CreateR2FromArray2DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
@@ -1107,7 +1111,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1123,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
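The parameterized R4/R3/R2 reduce-window classes above share one skeleton, and the rename touches it in two places: the input-literal factory and the init-value scalar. A condensed sketch, assuming the fixture's builder b, plus input, param, and kInitValue as defined in the tests above:

// Sketch of the shared parameterized reduce-window setup, post-rename.
std::unique_ptr<Literal> input_literal =
    LiteralUtil::CreateR2FromArray2DWithLayout(  // was Literal::...
        input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
                                                   &b, &parameter);
auto init_value =
    CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);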
@@ -1292,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
+ LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
@@ -1304,7 +1308,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1323,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index bebd814fa8..d544968648 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -91,10 +91,10 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*Literal::CreateR0<int32>(2))
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*Literal::CreateR0<int32>(3))
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
std::unique_ptr<Literal> literal =
client_
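replay_test.cc gets the same two changes in miniature: the include swap and the scalar factories used to stage parameters. A sketch of the staging call, assuming the test's client_ member:

// Transferring a scalar parameter with the renamed factory.
std::unique_ptr<GlobalData> x_data =
    client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
        .ConsumeValueOrDie();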
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index 5812fe442b..7c0389cfa3 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index d3d6c3c7d7..46d91711a5 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -55,39 +55,39 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -97,7 +97,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
@@ -105,7 +105,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
/*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<float>(1.0f);
+ auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -113,14 +113,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(1.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
auto a = Neg(parameter);
Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
- auto expected_literal = Literal::CreateR1<float>({-1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -128,12 +128,12 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XLA_TEST_P(ReshapeTest, Trivial0x3) {
XlaBuilder builder(TestName());
Array2D<float> input_array(0, 3);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -142,12 +142,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -155,12 +155,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XLA_TEST_P(ReshapeTest, Trivial3x0) {
XlaBuilder builder(TestName());
Array2D<float> input_array(3, 0);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -168,12 +168,12 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) {
// Collapses a 2-dimensional row vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -181,12 +181,12 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
// Collapses a 2-dimensional column vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
+ auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -194,13 +194,13 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
// Splits an empty vector into an empty matrix.
XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({});
+ auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -209,14 +209,14 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
XlaBuilder builder(TestName());
auto input_literal =
- Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 3});
auto expected_literal =
- Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
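The reshape tests repeat one flow: build an input literal, stage it as a parameter, reshape, and compare against an expected literal. A self-contained sketch of that flow after the rename, assuming the ReshapeTest fixture helpers (CreateParameterAndTransferLiteral, zero_error_spec_) seen above:

// Sketch: reshape a 6-element R1 into a 2x3 R2 and compare.
XlaBuilder builder(TestName());
auto input_literal =
    LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                               &builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3});
auto expected_literal =
    LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
                         zero_error_spec_);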
@@ -224,13 +224,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
// Transposes a 2x0 array to a 0x2 array.
XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -239,7 +239,7 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
XlaBuilder builder(TestName());
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
- auto input_literal = Literal::CreateFromArray(*simple);
+ auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -247,7 +247,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
/*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -256,7 +256,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -264,7 +264,7 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
/*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -272,12 +272,12 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
// Transposes a 0x4 array with XlaBuilder::Transpose.
XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}, {}, {}});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -286,14 +286,14 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) {
XLA_TEST_P(ReshapeTest, Transpose4x3) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -302,26 +302,27 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
// rearrangement of the originals (split), but no reordering (no shuffle).
XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 3, 0, 0});
- auto expected_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 0, 0));
+ auto expected_literal =
+ LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0));
+ auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{24, 0});
- auto expected_literal = Literal::CreateFromArray(Array2D<float>(24, 0));
+ auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -331,7 +332,7 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -339,20 +340,20 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
/*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 0});
- auto expected_literal = Literal::CreateFromArray(Array2D<float>(3, 0));
+ auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -362,7 +363,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -370,7 +371,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
/*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
- auto expected_literal = Literal::CreateFromArray(expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -388,13 +389,13 @@ static Array3D<float> ArrayForDocR3Tests() {
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{24});
- auto expected_literal = Literal::CreateR1<float>(
+ auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -403,33 +404,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{8, 3});
- auto expected_literal = Literal::CreateR2<float>({{10, 11, 12},
- {15, 16, 17},
- {20, 21, 22},
- {25, 26, 27},
- {30, 31, 32},
- {35, 36, 37},
- {40, 41, 42},
- {45, 46, 47}});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{10, 11, 12},
+ {15, 16, 17},
+ {20, 21, 22},
+ {25, 26, 27},
+ {30, 31, 32},
+ {35, 36, 37},
+ {40, 41, 42},
+ {45, 46, 47}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{24});
- auto expected_literal = Literal::CreateR1<float>(
+ auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -438,33 +439,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{8, 3});
- auto expected_literal = Literal::CreateR2<float>({{10, 20, 30},
- {40, 11, 21},
- {31, 41, 12},
- {22, 32, 42},
- {15, 25, 35},
- {45, 16, 26},
- {36, 46, 17},
- {27, 37, 47}});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{10, 20, 30},
+ {40, 11, 21},
+ {31, 41, 12},
+ {22, 32, 42},
+ {15, 25, 35},
+ {45, 16, 26},
+ {36, 46, 17},
+ {27, 37, 47}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{2, 6, 2});
- auto expected_literal = Literal::CreateR3<float>(
+ auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -491,12 +492,12 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
Array4D<float> t2x2x2x3(2, 2, 2, 3);
auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
t2x2x2x3.FillWithYX(*filler2x3);
- auto input_literal = Literal::CreateFromArray(t2x2x2x3);
+ auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
- auto expected_literal = Literal::CreateR2<float>(
+ auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
@@ -516,7 +517,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 0, 1) = 5;
t(1, 0, 1, 0) = 6;
t(1, 0, 1, 1) = 7;
- auto input_literal = Literal::CreateFromArray(t);
+ auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -524,7 +525,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
/*new_sizes=*/{2, 4});
auto expected_literal =
- Literal::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
+ LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -545,7 +546,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
&b, &parameter);
Reshape(parameter, dimensions, {});
- auto expected_literal = Literal::CreateR0<float>(83.0f);
+ auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -553,7 +554,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
@@ -565,7 +566,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f, 2.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
@@ -577,7 +578,8 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
XlaBuilder builder(TestName());
// clang-format off
- auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D<float>{
+ auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ Array4D<float>{
{
{
{0, 1},
@@ -622,16 +624,16 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<float>(expected_array);
+ LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = Literal::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(*expected);
}
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
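The bfloat16 path changes too: ConvertF32ToBF16 moves to LiteralUtil along with the factories. A sketch of the comparison branch just above, with expected_array and use_bfloat16() as in the test:

// Convert the F32 expectation to BF16 when the test runs in bfloat16 mode.
std::unique_ptr<Literal> expected =
    LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
  expected = LiteralUtil::ConvertF32ToBF16(*expected);
}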
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
@@ -642,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
// clang-format off
- auto expected_literal = Literal::CreateR4<float>({
+ auto expected_literal = LiteralUtil::CreateR4<float>({
{{{0, 1, 2, 3}},
{{4, 5, 6, 7}}},
{{{100, 101, 102, 103}},
@@ -658,7 +660,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
@@ -669,7 +671,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
// clang-format off
- auto expected_literal = Literal::CreateR4<float>({
+ auto expected_literal = LiteralUtil::CreateR4<float>({
{{{0, 100, 200, 1}},
{{101, 201, 2, 102}}},
{{{202, 3, 103, 203}},
@@ -691,7 +693,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
@@ -699,7 +701,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
+ LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -713,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
@@ -721,7 +723,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
+ LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -736,7 +738,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
@@ -749,7 +751,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
*cell;
});
- auto expected = Literal::CreateR2FromArray2D(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -763,7 +765,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
@@ -785,7 +787,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = Literal::ConvertF32ToBF16(*input_literal);
+ auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else {
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
@@ -794,7 +796,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
XlaBuilder builder(TestName());
- auto literal_1x2x3x4 = Literal::CreateR4<float>(
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4<float>(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
@@ -808,7 +810,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
- auto literal_1x2x3x4 = Literal::CreateR4<float>(
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4<float>(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
@@ -820,7 +822,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
/*new_sizes=*/{2, 4, 3, 1});
// clang-format off
- auto expected_2x4x3x1 = Literal::CreateR4<float>(
+ auto expected_2x4x3x1 = LiteralUtil::CreateR4<float>(
{{{{1}, {5}, {9}},
{{2}, {6}, {10}},
{{3}, {7}, {11}},
@@ -844,7 +846,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
@@ -854,7 +856,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -873,7 +875,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
@@ -883,7 +885,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -902,7 +904,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
@@ -912,7 +914,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -932,7 +934,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
@@ -942,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -961,7 +963,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
@@ -971,7 +973,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
->Relayout(input_literal->shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
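The reshape-test hunks above are all the same mechanical migration: the static literal factories move from the Literal class to the LiteralUtil helper, with signatures otherwise unchanged. A minimal before/after sketch, assuming only the qualifier changes (as every hunk in this file suggests; the test scaffolding is elided):

    #include <memory>
    #include <utility>
    #include "tensorflow/compiler/xla/literal_util.h"

    // Before this commit the factories were static members of Literal:
    //   auto lit = Literal::CreateR2<float>({{1, 2}, {3, 4}});
    // After it, the same calls are qualified with LiteralUtil:
    std::unique_ptr<xla::Literal> MakeExpected(bool use_bfloat16) {
      auto lit = xla::LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
      // Shape/layout helpers and conversions move the same way:
      auto reshaped = xla::LiteralUtil::ReshapeSlice({4, 1}, {1, 0}, *lit);
      return use_bfloat16 ? xla::LiteralUtil::ConvertF32ToBF16(*reshaped)
                          : std::move(reshaped);
    }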
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 662bc42224..23f0d26d93 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -82,7 +82,7 @@ TEST_P(FloatReverseTest, Reverses) {
std::vector<float> input_vector(
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
- auto r1_literal = Literal::CreateR1<float>(input_vector);
+ auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index 7cfca781ac..a620fe1908 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/packed_literal_reader.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index f334a8c131..a8193c2eac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*Literal::CreateR0<int32>(42));
+ RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*Literal::CreateR0<float>(42.0));
+ RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*Literal::CreateR1<float>({}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*Literal::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(*Literal::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(
+ *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*Literal::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*Literal::CreateR4<float>({{
+ RoundTripTest(*LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*Literal::MakeTuple({}));
+ RoundTripTest(*LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR0<float>(1.0).get(),
- Literal::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
+ LiteralUtil::CreateR1<int>({2, 3}).get()}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*Literal::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 3afd8c8fc8..3b603c0d31 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -162,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
@@ -226,9 +227,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = Literal::CreateR0<float>(0.5f);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
@@ -375,8 +376,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -387,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend / divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -416,8 +418,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -428,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend % divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -440,7 +443,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(87919);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index 0a173fbbbd..b1f1e69d3c 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 3e5c01d6d4..48138e7b07 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -170,7 +170,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
/*strides=*/{{1, 1, 2, 1}});
- auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
+ auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
auto original = ConstantR4FromArray4D(&builder, values);
@@ -197,7 +197,7 @@ class SliceR1Test : public ClientLibraryTestBase,
// vector<bool>.
tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
- auto literal = Literal::CreateR1<NativeT>(input);
+ auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
auto original = Parameter(&builder, 0, literal->shape(), "p0");
@@ -368,7 +368,7 @@ XLA_TEST_P(SliceR2Test, DoIt) {
const R2Spec& spec = GetParam();
Array2D<int32> input(spec.input_dim0, spec.input_dim1);
input.FillUnique();
- auto literal = Literal::CreateR2FromArray2DWithLayout(
+ auto literal = LiteralUtil::CreateR2FromArray2DWithLayout(
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
@@ -463,7 +463,7 @@ class SliceR4Test : public ClientLibraryTestBase,
auto expected = ReferenceUtil::Slice4D(
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
XlaBuilder builder(TestName());
- auto literal = Literal::CreateR4FromArray4DWithLayout(
+ auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
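slice_test.cc exercises the layout-aware variants of the same factories; a small sketch of that shape, assuming the signatures shown in the hunks above:

    #include <memory>
    #include "tensorflow/compiler/xla/array2d.h"
    #include "tensorflow/compiler/xla/layout_util.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    std::unique_ptr<xla::Literal> MakeColumnMajorInput() {
      xla::Array2D<xla::int32> input(2, 3);
      input.FillUnique();  // distinct values, as the R2 slice test does
      // Minor-to-major {0, 1} stores the array column-major on the host.
      return xla::LiteralUtil::CreateR2FromArray2DWithLayout(
          input, xla::LayoutUtil::MakeLayout({0, 1}));
    }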
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 20c7c30878..2647937013 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
@@ -110,7 +111,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
MakeFakeLiteralInternal(element_shape, engine));
elements.push_back(std::move(element));
}
- return Literal::MakeTupleOwned(std::move(elements));
+ return LiteralUtil::MakeTupleOwned(std::move(elements));
}
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
@@ -220,7 +221,7 @@ std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
start_indices[i] = generator(*engine);
}
}
- return Literal::CreateR1<int32>(start_indices);
+ return LiteralUtil::CreateR1<int32>(start_indices);
}
// Use dataflow analysis on each parameter to see if there are uses that would
@@ -318,9 +319,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant != nullptr) {
switch (constant_type) {
case ConstantType::kZero:
- return Literal::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
case ConstantType::kOne:
- return Literal::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
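For constrained parameters, test_utils.cc now reaches the identity-element helpers through LiteralUtil as well. A sketch of that path, assuming Zero/One return a Literal whose CloneToUnique() copies it into a std::unique_ptr, which is what the call shapes above imply:

    #include <memory>
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    std::unique_ptr<xla::Literal> IdentityConstant(xla::PrimitiveType type,
                                                   bool want_one) {
      return want_one ? xla::LiteralUtil::One(type).CloneToUnique()
                      : xla::LiteralUtil::Zero(type).CloneToUnique();
    }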
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index a8689f6498..e59f215a9a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index e9008fa48a..2bdbd08309 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -31,21 +31,21 @@ class TokenHloTest : public HloTestBase {};
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ builder.AddInstruction(HloInstruction::CreateToken());
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
std::unique_ptr<HloModule> module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
- auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
- auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto token2 = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
HloInstruction::CreateAfterAll({token0, token0, token1, token2}));
@@ -53,7 +53,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -64,7 +64,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
builder.AddInstruction(
HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1"));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
module->AddEntryComputation(builder.Build());
Status status = HloVerifier().Run(module.get()).status();
@@ -98,7 +98,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
builder.AddInstruction(HloInstruction::CreateAfterAll({param}));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(123)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
Status status = HloVerifier().Run(module.get()).status();
@@ -184,7 +184,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
- auto arg = Literal::CreateR0<bool>(true);
+ auto arg = LiteralUtil::CreateR0<bool>(true);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
Execute(std::move(module), {arg.get()}));
EXPECT_EQ(42, result->Get<int32>({}));
@@ -195,7 +195,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
- auto arg = Literal::CreateR0<bool>(false);
+ auto arg = LiteralUtil::CreateR0<bool>(false);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
Execute(std::move(module), {arg.get()}));
EXPECT_EQ(7, result->Get<int32>({}));
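token_hlo_test.cc picks up a second rename on top of the literal one: a nullary HloInstruction::CreateAfterAll({}) becomes the dedicated HloInstruction::CreateToken(), while CreateAfterAll is kept for joining tokens that already exist. The matching literal-side factory moves too (LiteralUtil::CreateToken()). A sketch of the intended split, using both factories as the hunks above use them:

    #include "tensorflow/compiler/xla/service/hlo_computation.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    void BuildTokenComputation() {
      auto builder = xla::HloComputation::Builder("tokens");
      // Fresh tokens now come from CreateToken() ...
      auto* t0 = builder.AddInstruction(xla::HloInstruction::CreateToken());
      auto* t1 = builder.AddInstruction(xla::HloInstruction::CreateToken());
      // ... and CreateAfterAll only combines existing ones.
      builder.AddInstruction(xla::HloInstruction::CreateAfterAll({t0, t1}));
    }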
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 86babb58c9..0f86b7f20f 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -68,7 +68,7 @@ class TransferManagerTest : public LocalClientTestBase {
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = Literal::CreateR0<uint32>(42);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
@@ -84,7 +84,7 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) {
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
std::unique_ptr<Literal> literal =
- Literal::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
+ LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
@@ -102,7 +102,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) {
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = Literal::CreateR1<float>(test_vector);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
@@ -118,7 +118,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = Literal::CreateR1U8(test_string);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
@@ -134,7 +134,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) {
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
@@ -151,7 +151,7 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) {
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -172,10 +172,10 @@ XLA_TEST_F(TransferManagerTest,
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR0<float>(123.0f).get(),
- Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
@@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple({});
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
@@ -203,13 +203,13 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR0<float>(123.0f).get(),
- Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
.get(),
- Literal::CreateR1<float>({-10.0f, 123.0f}).get()});
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
@@ -223,7 +223,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
@@ -238,12 +238,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
.get(),
- Literal::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- Literal::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
@@ -265,25 +265,25 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
}
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
- std::unique_ptr<Literal> literal1 = Literal::MakeTuple(
- {Literal::CreateR0<float>(123.0f).get(),
- Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
+ std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
.get(),
- Literal::CreateR1<float>({-10.0f, 123.0f}).get()});
- std::unique_ptr<Literal> literal2 = Literal::MakeTuple(
- {Literal::CreateR0<float>(456.0f).get(),
- Literal::MakeTuple(
- {Literal::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
+ std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(456.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
.get(),
- Literal::CreateR1<float>({-98.0f, 153.0f}).get()});
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
@@ -325,10 +325,10 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest {
std::vector<std::unique_ptr<Literal>> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
- Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
std::unique_ptr<Literal> literal =
- Literal::MakeTupleOwned(std::move(tuple_elements));
+ LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
auto device_buffer = AllocateDeviceBuffer(literal->shape());
TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
device_buffer));
@@ -357,10 +357,10 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest {
std::vector<std::unique_ptr<Literal>> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
- Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
std::unique_ptr<Literal> literal =
- Literal::MakeTupleOwned(std::move(tuple_elements));
+ LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
auto device_buffer = AllocateDeviceBuffer(literal->shape());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
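Two tuple constructors recur in the transfer-manager hunks. MakeTuple borrows its elements (hence all the .get() calls, with the unique_ptrs kept alive by the caller), while MakeTupleOwned takes ownership, as the benchmarks above use it. A sketch of the distinction under those assumptions:

    #include <memory>
    #include <utility>
    #include <vector>
    #include "tensorflow/compiler/xla/literal_util.h"

    void MakeTuples() {
      // Borrowing form: elements stay owned by the caller.
      auto a = xla::LiteralUtil::CreateR0<float>(123.0f);
      auto b = xla::LiteralUtil::CreateR1<float>({44.0f, -10.0f});
      auto borrowed = xla::LiteralUtil::MakeTuple({a.get(), b.get()});

      // Owning form: the vector's elements are moved into the tuple.
      std::vector<std::unique_ptr<xla::Literal>> elems;
      elems.push_back(xla::LiteralUtil::CreateR0<float>(123.0f));
      elems.push_back(xla::LiteralUtil::CreateR1<float>({44.0f, -10.0f}));
      auto owned = xla::LiteralUtil::MakeTupleOwned(std::move(elems));
    }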
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index ec11508891..bf86c5dfb6 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -49,10 +49,10 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get(),
- Literal::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
ConstantLiteral(&builder, *value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
@@ -64,9 +64,9 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
- Literal::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
+ LiteralUtil::CreateR0<float>(constant_scalar2).get()});
ConstantLiteral(&builder, *value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
@@ -86,10 +86,10 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get(),
- Literal::CreateR2<float>(constant_matrix).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -100,8 +100,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
- Literal::CreateR1<float>({}).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
+ LiteralUtil::CreateR1<float>({}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -109,7 +110,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
- auto expected = Literal::MakeTuple({});
+ auto expected = LiteralUtil::MakeTuple({});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -193,9 +194,9 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected =
- Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
- Literal::CreateR1<float>(constant_vector).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>(constant_matrix).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -216,8 +217,8 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
- Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
- Literal::CreateR0<bool>(!direction).get()});
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
+ LiteralUtil::CreateR0<bool>(!direction).get()});
ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
error_spec_);
@@ -284,8 +285,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
- Literal::CreateR1<float>(vec1).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -328,8 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
- Literal::CreateR1<float>(vec2).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
+ LiteralUtil::CreateR1<float>(vec2).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -403,8 +406,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
- Literal::CreateR1<float>(vec1).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -414,13 +418,13 @@ XLA_TEST_F(TupleTest, NestedTuples) {
ConstantR0<float>(&builder, 42.0)});
Tuple(&builder, {inner_tuple, ConstantR1<float>(&builder, {22.0, 44.0})});
- auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
- auto expected_s = Literal::CreateR0<float>(42.0);
+ auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
+ auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- Literal::MakeTuple({expected_v1.get(), expected_s.get()});
- auto expected_v2 = Literal::CreateR1<float>({22.0, 44.0});
+ LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
auto expected =
- Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -440,14 +444,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*Literal::MakeTuple({
- Literal::MakeTuple(
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::MakeTuple(
{
- Literal::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- Literal::CreateR1<float>({4.0, 5.0, 6.0}).get(),
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
})
.get(),
- Literal::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
}))
.ConsumeValueOrDie();
@@ -478,11 +482,12 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*Literal::MakeTuple(
- {Literal::CreateR0<complex64>({1, 2}).get(),
- Literal::MakeTuple(
- {Literal::CreateR1<complex64>({{10, 20}, {30, 40}}).get(),
- Literal::CreateR2<complex64>(
+ ->TransferToServer(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
+ .get(),
+ LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
{{10000, 20000}, {30000, 40000}}})
@@ -491,11 +496,13 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
- ->TransferToServer(*Literal::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ ->TransferToServer(
+ *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
- auto sum = Literal::CreateR2<complex64>({{{111, 222}, {331, 442}},
- {{1011, 2022}, {3031, 4042}},
- {{10011, 20022}, {30031, 40042}}});
+ auto sum =
+ LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
+ {{1011, 2022}, {3031, 4042}},
+ {{10011, 20022}, {30031, 40042}}});
auto prod = MakeUnique<Literal>(sum->shape());
ASSERT_TRUE(prod->Populate<complex64>(
[&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
@@ -505,9 +512,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
: complex64(1, -2));
})
.ok());
- auto expected =
- Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(),
- Literal::CreateR0<complex64>({123, 456}).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
+ LiteralUtil::CreateR0<complex64>({123, 456}).get()});
ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -530,10 +537,11 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::MakeTupleOwned(Literal::CreateR1<float>({1, 2, 3}));
+ auto param =
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{1, 2, 3}})),
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
*result));
}
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 929b1ca7fb..a90a6fb0a5 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -101,7 +101,7 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
Abs(arg);
std::unique_ptr<Literal> expected =
- Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
+ LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -113,7 +113,7 @@ void UnaryOpTest::SignTestHelper<complex64>() {
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
Sign(arg);
- std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -128,7 +128,7 @@ void UnaryOpTest::SignAbsTestHelper<complex64>() {
Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
std::unique_ptr<Literal> expected =
- Literal::CreateR1<complex64>({0, 0, 0, 0});
+ LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -173,7 +173,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) {
Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
std::unique_ptr<Literal> expected =
- Literal::CreateR0<complex64>({-2.6f, 0.8f});
+ LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index bbd67cd8d7..29befef92e 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// the sum will increase by 1.0. It will first be >15.5 when the elements
// have all reached 2.0.
auto expected_data =
- Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = Literal::MakeTuple({expected_data.get()});
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
+ auto expected = LiteralUtil::MakeTuple({expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(N);
- auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
- auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f});
- auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(N);
+ auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
+ auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
+ auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
+ expected_w3.get(), expected_w1.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_predicate = Literal::CreateR0<bool>(true);
- auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_counter.get(), expected_predicate.get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
}
@@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR0<int32>(7);
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR0<int32>(7);
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
While(cond_computation, body_computation, t);
- auto expected_element = Literal::CreateR1<float>({1, 1});
+ auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- Literal::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<float>(42)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<int32>(1)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
- auto add1 = Literal::CreateR0<int32>(15);
- auto add2 = Literal::CreateR0<int32>(16);
- auto expected = Literal::MakeTuple({add1.get(), add2.get()});
+ auto add1 = LiteralUtil::CreateR0<int32>(15);
+ auto add2 = LiteralUtil::CreateR0<int32>(16);
+ auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1226,9 +1227,9 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
auto while_instruction = While(condition, body, init);
GetTupleElement(while_instruction, 3);
- TF_ASSERT_OK_AND_ASSIGN(auto param_value,
- client_->TransferToServer(*Literal::CreateR2<float>(
- {{1.0, 2.0}, {-1.0, -2.0}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ {{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
&builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7dba058d40..4d4dd62a3f 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -79,7 +79,9 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
- gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
+ tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
+ {}) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
@@ -113,7 +115,9 @@ Status ParseOneProfileOutputLine(
", Regexp: ", regexp_pattern);
}
- InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ }
return Status::OK();
}
@@ -168,7 +172,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
auto execution_result,
executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
&hlo_execution_profile));
- TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
(void)execution_result;
*profile_output =
@@ -267,7 +270,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
auto matrix = GetTupleElement(state, 1);
auto next_iteration =
Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1));
- Tuple(&builder, {next_iteration, Add(matrix, matrix)});
+ Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
@@ -289,36 +292,50 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
tensorflow::str_util::Split(profile_output, '\n');
auto while_body_profile_start =
- std::find_if(profile_output_lines.begin(), profile_output_lines.end(),
+ c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
+ return tensorflow::str_util::StartsWith(s,
+ "Execution profile for body");
+ });
+
+ ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
+
+ auto while_body_profile_end =
+ std::find_if(while_body_profile_start, profile_output_lines.end(),
[](tensorflow::StringPiece s) {
return tensorflow::str_util::StartsWith(
- s, "Execution profile for body");
+ s, "********** microseconds report **********");
});
- ASSERT_NE(while_body_profile_start, profile_output_lines.end());
+ // We emit a blank line before the "********** microseconds report **********"
+ // line.
+ while_body_profile_end--;
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ ASSERT_NE(while_body_profile_end, profile_output_lines.end());
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_hlo=*/false, &parsed_profile_lines));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_hlo=*/true, &parsed_profile_lines));
+ for (auto while_body_profile_i = while_body_profile_start + 1;
+ while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
+ // There are multiple "get-tuple-element" instructions in the while body so
+ // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ *while_body_profile_i,
+ /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1),
+ &parsed_profile_lines, {"get-tuple-element"}));
+ }
TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
MaybeFind(parsed_profile_lines, "[total]"));
- TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
- MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
+ MaybeFind(parsed_profile_lines, "multiply"));
EXPECT_GT(total_while_body_profile.cycles, 0);
EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
- EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);
- EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
- EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
+ EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
+ EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 56702feab9..897123d760 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index e45e5291c9..708e8c80d8 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 23070b6638..92f9b4f9f0 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 373c0d2d8d..24e0784741 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
index 0a1235b5e0..159ac1b7e1 100644
--- a/tensorflow/compiler/xla/text_literal_writer.h
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 70cf2fb1b8..4ea02faffc 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -30,8 +31,9 @@ namespace xla {
namespace {
TEST(TextLiteralWriterTest, WritesFloatLiteral) {
- auto literal = Literal::CreateR2<float>({
- {3.14, 2.17}, {1.23, 4.56},
+ auto literal = LiteralUtil::CreateR2<float>({
+ {3.14, 2.17},
+ {1.23, 4.56},
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index e4a052c8f1..55501827f2 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -74,7 +74,7 @@ cc_library(
srcs = ["replay_computation.cc"],
deps = [
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -123,7 +123,7 @@ tf_cc_binary(
name = "show_literal",
srcs = ["show_literal.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
@@ -145,7 +145,7 @@ tf_cc_binary(
name = "show_text_literal",
srcs = ["show_text_literal.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:text_literal_reader",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 3a7917cf30..854e797ec2 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -43,7 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index fe8e72ba32..51909190a3 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <stdio.h>
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 8525873e91..48c8374811 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/text_literal_reader.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index b23b968aae..5ae099a462 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -500,17 +500,17 @@ bool c_is_sorted(const C& c, Compare&& comp) {
}
template <typename C>
-auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) {
+auto c_adjacent_find(C& c) -> decltype(std::begin(c)) {
return std::adjacent_find(std::begin(c), std::end(c));
}
template <typename C, typename Pred>
-auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) {
+auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) {
return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
}
template <typename C, typename Value>
-auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) {
+auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) {
return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
}
@@ -562,6 +562,11 @@ void EraseAt(C* c, int64 index) {
c->erase(c->begin() + index);
}
+template <typename T>
+std::vector<T> ArraySliceToVector(tensorflow::gtl::ArraySlice<T> slice) {
+ return std::vector<T>(slice.begin(), slice.end());
+}
+
template <typename T, int N>
std::vector<T> InlinedVectorToVector(
const tensorflow::gtl::InlinedVector<T, N>& inlined_vector) {
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index c039624daa..60be9db263 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -27,7 +27,6 @@ py_library(
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/boosted_trees:init_py",
"//tensorflow/contrib/checkpoint/python:checkpoint",
- "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/coder:coder_py",
"//tensorflow/contrib/compiler:compiler_py",
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 7e26f47118..679ab48e5c 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -4,7 +4,7 @@ IMPORTANT: AutoGraph is alpha software, and under active development. Expect rou
AutoGraph is a Python to TensorFlow compiler.
-With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops.
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
For example, this Python function:
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index 361cf2d77c..7821c98f1c 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -29,6 +29,9 @@ from tensorflow.contrib.autograph.impl.api import converted_call
from tensorflow.contrib.autograph.impl.api import do_not_convert
from tensorflow.contrib.autograph.impl.api import RunMode
from tensorflow.contrib.autograph.impl.api import to_code
+from tensorflow.contrib.autograph.core.errors import improved_errors
+from tensorflow.contrib.autograph.core.errors import GraphConstructionError
+from tensorflow.contrib.autograph.core.errors import TfRuntimeError
from tensorflow.contrib.autograph.impl.api import to_graph
from tensorflow.contrib.autograph.lang.directives import set_element_type
from tensorflow.contrib.autograph.lang.directives import set_loop_options
@@ -46,6 +49,10 @@ _allowed_symbols = [
'to_graph',
# Overloaded operators
'operators',
+ # Errors
+ 'improved_errors',
+ 'GraphConstructionError',
+ 'TfRuntimeError',
# Python language "extensions"
'set_element_type',
'set_loop_options',
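For context, the three newly exported error symbols are meant to be used together with to_graph. A minimal usage sketch follows; `my_func`, the constant input, and the session setup are illustrative, not part of this change:

import tensorflow as tf
from tensorflow.contrib import autograph as ag

def my_func(x):
  return x * 3

converted_my_func = ag.to_graph(my_func)
ops = converted_my_func(tf.constant(1))  # construction errors are rewritten
                                         # into GraphConstructionError

with ag.improved_errors(converted_my_func):
  with tf.Session() as sess:
    sess.run(ops)  # OpErrors from converted code surface as TfRuntimeError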
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD
index b2e2e27673..33d8d517a5 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/contrib/autograph/converters/BUILD
@@ -24,6 +24,7 @@ py_library(
"continue_statements.py",
"control_flow.py",
"decorators.py",
+ "error_handlers.py",
"ifexp.py",
"list_comprehension.py",
"lists.py",
@@ -216,6 +217,18 @@ py_test(
)
py_test(
+ name = "error_handlers_test",
+ srcs = ["error_handlers_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "slices_test",
srcs = ["slices_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py
index e4e8eda42f..6325ac78dc 100644
--- a/tensorflow/contrib/autograph/converters/__init__.py
+++ b/tensorflow/contrib/autograph/converters/__init__.py
@@ -18,5 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(mdan): Define a base transformer class that can recognize skip_processing
-# TODO(mdan): All converters are incomplete, especially those that change blocks
+# Naming conventions:
+# * each converter should specialize on a single idiom; be consistent with
+# the Python reference for naming
+# * all converters inherit core.converter.Base
+# * module names describe the idiom that the converter covers, plural
+# * the converter class is named consistent with the module, singular and
+# includes the word Transformer
+#
+# Example:
+#
+# lists.py
+# class ListTransformer(converter.Base)
diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py
new file mode 100644
index 0000000000..3f23662152
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/error_handlers.py
@@ -0,0 +1,52 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wraps function bodies with a try/except to rewrite error tracebacks.
+
+Only adds try/except wrappers to functions that have the anno.Basic.ORIGIN
+annotation because these are the functions originally written by the user.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import templates
+
+
+class ErrorRewritingTransformer(converter.Base):
+ """Possibly wraps the body of a function in a try/except.
+
+ Only wraps functions that were originally defined by the user, detected by
+ checking for the anno.Basic.ORIGIN annotation.
+ """
+
+ def visit_FunctionDef(self, node):
+ node = self.generic_visit(node)
+
+ if anno.hasanno(node, anno.Basic.ORIGIN):
+ template = """
+ try:
+ body
+ except:
+ ag__.rewrite_graph_construction_error(ag_source_map__)
+ """
+ node.body = templates.replace(template, body=node.body)
+ return node
+
+
+def transform(node, ctx):
+ return ErrorRewritingTransformer(ctx).visit(node)
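For orientation, a hedged sketch of what this template yields for a converted user function; `user_fn` and `do_work` are illustrative, while `ag__` and `ag_source_map__` are the names referenced by the template above:

def user_fn(x):
  try:
    return do_work(x)  # original function body, unchanged
  except:
    # Re-raises the error as a GraphConstructionError with rewritten frames.
    ag__.rewrite_graph_construction_error(ag_source_map__)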
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py
new file mode 100644
index 0000000000..408e35b4b6
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for error_handlers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.converters import error_handlers
+from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.contrib.autograph.core import errors
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.platform import test
+
+
+class ErrorHandlersTest(converter_testing.TestCase):
+
+ def compiled_fn(self, test_fn, add_origin=False):
+ node = self.parse_and_analyze(test_fn, {})
+ if add_origin:
+ anno.setanno(node.body[0], anno.Basic.ORIGIN,
+ origin_info.OriginInfo(__file__, None, None, None, None))
+ node = error_handlers.transform(node, self.ctx)
+ module = self.compiled(node)
+ return module
+
+ def test_no_origin_annotation(self):
+
+ def test_fn():
+ raise ValueError('Crash!')
+
+ with self.compiled_fn(test_fn) as result:
+ with self.assertRaises(ValueError):
+ result.test_fn()
+
+ def test_wraps_body(self):
+
+ def test_fn():
+ raise ValueError('Crash!')
+
+ with self.compiled_fn(test_fn, add_origin=True) as result:
+ result.rewrite_graph_construction_error = None
+ with self.assertRaises(errors.GraphConstructionError):
+ result.test_fn()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py
index a351cd81b8..3b9c9a06d8 100644
--- a/tensorflow/contrib/autograph/converters/single_return.py
+++ b/tensorflow/contrib/autograph/converters/single_return.py
@@ -224,11 +224,6 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
- def visit_Try(self, node):
- self.cant_return = True
- self.generic_visit(node)
- self.cant_return = False
-
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py
index 3f5fc57125..de04cc9184 100644
--- a/tensorflow/contrib/autograph/converters/slices.py
+++ b/tensorflow/contrib/autograph/converters/slices.py
@@ -56,8 +56,7 @@ class SliceTransformer(converter.Base):
def visit_Subscript(self, node):
node = self.generic_visit(node)
if not isinstance(node.slice, gast.Index):
- # TODO(mdan): It might make more sense to wave them through.
- raise NotImplementedError('non-index slice')
+ return node
if not isinstance(node.ctx, gast.Load):
# Index writes are handled at a higher level, one at which the rvalue is
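The practical effect of this change, sketched with plain Python (illustrative only; gast parses x[i] as gast.Index and x[a:b] as gast.Slice):

x = [1, 2, 3, 4]
i = 0
print(x[i])    # gast.Index: still rewritten by SliceTransformer
print(x[1:3])  # gast.Slice: previously raised NotImplementedError in the
               # converter; after this change the node is passed through as-is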
diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD
index 833f9dced8..1873045a92 100644
--- a/tensorflow/contrib/autograph/core/BUILD
+++ b/tensorflow/contrib/autograph/core/BUILD
@@ -19,6 +19,7 @@ py_library(
srcs = [
"config.py",
"converter.py",
+ "errors.py",
"naming.py",
],
srcs_version = "PY2AND3",
@@ -30,6 +31,31 @@ py_library(
],
)
+py_test(
+ name = "errors_test",
+ srcs = ["errors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_library(
name = "test_lib",
srcs = [
@@ -47,13 +73,3 @@ py_library(
"@six_archive//:six",
],
)
-
-py_test(
- name = "naming_test",
- srcs = ["naming_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":core",
- "//tensorflow/python:client_testlib",
- ],
-)
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py
index 54e6aa0f3b..a93e4a8064 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/contrib/autograph/core/converter.py
@@ -64,15 +64,29 @@ from __future__ import division
from __future__ import print_function
import collections
+from enum import Enum
+
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import naming
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import activity
+from tensorflow.contrib.autograph.pyct.static_analysis import live_values
+from tensorflow.contrib.autograph.pyct.static_analysis import liveness
+from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.contrib.autograph.pyct.static_analysis import type_info
# TODO(mdan): These contexts can be refactored into first class objects.
# For example, we could define Program and Entity abstractions that hold on
# to the actual entity and have conversion methods.
+# TODO(mdan): Add a test specific to this converter.
+
class ProgramContext(object):
"""ProgramContext keeps track of converting function hierarchies.
@@ -197,6 +211,46 @@ class Base(transformer.Base):
self._used = False
self._ast_depth = 0
+ def get_definition_directive(self, node, directive, arg, default):
+ """Returns the unique directive for a symbol, or a default if none exist.
+
+ See lang/directives.py for details on directives.
+
+ Args:
+ node: ast.AST
+ directive: Callable[..., Any]
+ arg: str
+ default: Any
+
+ Raises:
+ ValueError: if conflicting annotations have been found
+ """
+ defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
+ if not defs:
+ return default
+
+ # TODO(mdan): Simplify this.
+ arg_values = []
+ for def_ in defs:
+ if (directive not in def_.directives or
+ arg not in def_.directives[directive]):
+ continue
+ arg_value = def_.directives[directive][arg]
+ for prev_value in arg_values:
+ if not ast_util.matches(arg_value, prev_value):
+ qn = anno.getanno(node, anno.Basic.QN)
+ raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
+ (qn, directive.__name__, arg,
+ compiler.ast_to_source(arg_value).strip(),
+ compiler.ast_to_source(prev_value).strip()))
+ arg_values.append(arg_value)
+
+ if not arg_values:
+ return default
+
+ arg_value, = arg_values
+ return arg_value
+
def visit(self, node):
if not self._ast_depth:
if self._used:
@@ -208,3 +262,69 @@ class Base(transformer.Base):
return super(Base, self).visit(node)
finally:
self._ast_depth -= 1
+
+
+class AnnotatedDef(reaching_definitions.Definition):
+
+ def __init__(self):
+ super(AnnotatedDef, self).__init__()
+ self.directives = {}
+
+
+class AgAnno(Enum):
+ """Annotation labels specific to AutoGraph. See anno.py."""
+
+ DIRECTIVES = 'User directives associated with the annotated statement.'
+
+ def __repr__(self):
+ return self.name
+
+
+def standard_analysis(node, context, is_initial=False):
+ """Performs a complete static analysis of the given code.
+
+ Args:
+ node: ast.AST
+ context: converter.EntityContext
+ is_initial: bool, whether this is the initial analysis done on the input
+ source code
+
+ Returns:
+ ast.AST, same as node, with the static analysis annotations added
+ """
+ # TODO(mdan): Clear static analysis here.
+ # TODO(mdan): Consider not running all analyses every time.
+ # TODO(mdan): Don't return a node because it's modified by reference.
+ graphs = cfg.build(node)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, context.info, None)
+ node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef)
+ node = liveness.resolve(node, context.info, graphs)
+ node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
+ node = type_info.resolve(node, context.info)
+ # This second call allows resolving first-order class attributes.
+ node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
+ if is_initial:
+ anno.dup(
+ node,
+ {
+ anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
+ },
+ )
+ return node
+
+
+def apply_(node, context, converter_module):
+ """Applies a converter to an AST.
+
+ Args:
+ node: ast.AST
+ context: converter.EntityContext
+ converter_module: converter.Base
+
+ Returns:
+ ast.AST, the result of applying converter to node
+ """
+ node = standard_analysis(node, context)
+ node = converter_module.transform(node, context)
+ return node
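A hedged sketch of how the new helpers compose; it assumes `node` is a parsed ast.AST and `ctx` a converter.EntityContext obtained from the usual parse step, which is elided here:

from tensorflow.contrib.autograph.converters import error_handlers
from tensorflow.contrib.autograph.core import converter

node = converter.standard_analysis(node, ctx, is_initial=True)
node = converter.apply_(node, ctx, error_handlers)  # re-analyzes, then calls
                                                    # error_handlers.transform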
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py
index 0e46aacc12..c47b70f15c 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/contrib/autograph/core/converter_testing.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.autograph import operators
from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.core import errors
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import pretty_printer
@@ -89,6 +90,8 @@ class TestCase(test.TestCase):
fake_ag = self.make_fake_mod('fake_ag', converted_call)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__['utils'] = utils
+ fake_ag.__dict__['rewrite_graph_construction_error'] = (
+ errors.rewrite_graph_construction_error)
result.__dict__['ag__'] = fake_ag
yield result
except Exception: # pylint:disable=broad-except
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py
new file mode 100644
index 0000000000..e58745337a
--- /dev/null
+++ b/tensorflow/contrib/autograph/core/errors.py
@@ -0,0 +1,272 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Error rewriting logic.
+
+Contains the functions responsible for rewriting tracebacks of errors raised
+in AutoGraph (AG) generated code so that they refer only to the original,
+user-written code.
+
+When 'user code' is used in comments it refers to the original source code that
+the user wrote and is converting using AutoGraph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import logging
+import sys
+import traceback
+
+from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.util import tf_inspect
+
+
+class GraphConstructionError(Exception):
+ """Error for graph construction errors from AutoGraph generated code."""
+
+ def __init__(self, original_error, custom_traceback):
+ self.original_error = original_error
+ self.custom_traceback = custom_traceback
+ super(GraphConstructionError, self).__init__()
+
+ def __str__(self):
+ traceback_str = ''.join(traceback.format_list(self.custom_traceback))
+ return ('Traceback (most recent call last):\n' + traceback_str + '\n' + str(
+ self.original_error) + '\n')
+
+
+class TfRuntimeError(Exception):
+ """Error wrapper for runtime errors raised by AutoGraph generated code."""
+
+ def __init__(self, op_name, op_message, custom_traceback):
+ self.op_name = op_name
+ self.op_message = op_message
+ self.custom_traceback = custom_traceback
+ super(TfRuntimeError, self).__init__()
+
+ def __str__(self):
+ message = '%s\n\nCaused by op %r, defined at:\n' % (self.op_message,
+ self.op_name)
+ return message + ''.join(traceback.format_list(self.custom_traceback))
+
+
+def _rewrite_frame(source_map, cleaned_traceback, stack_frame_indices):
+ """Rewrites the stack frames at the given indices using the given source map.
+
+ Args:
+ source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
+ AG generated code.
+ cleaned_traceback: List[Tuple[text, text, text, text]], the current
+ traceback.
+ stack_frame_indices: Iterable[Int], frame indices to possibly rewrite if
+ there are matching source mapping keys.
+
+ Returns:
+ None
+ """
+ for frame_index in stack_frame_indices:
+ # (file_path, line number, function name, code)
+ file_path, line_number, _, _ = cleaned_traceback[frame_index]
+ source_map_key = CodeLocation(file_path=file_path, line_number=line_number)
+ found_mapping = source_map_key in source_map
+ if found_mapping:
+ cleaned_traceback[frame_index] = source_map[source_map_key].as_frame()
+
+
+# TODO(znado): Make more robust to name changes in the rewriting logic.
+def _remove_rewrite_frames(tb):
+ """Remove stack frames containing the error rewriting logic."""
+ cleaned_tb = []
+ for f in tb:
+ if 'ag__.rewrite_graph_construction_error' not in f[3]:
+ cleaned_tb.append(f)
+ return cleaned_tb
+
+
+def rewrite_graph_construction_error(source_map):
+ """Rewrites errors raised by non-AG APIs inside AG generated code.
+
+ Meant to be called from the try/except block inside each AutoGraph generated
+ function. Only rewrites the traceback frames corresponding to the function
+ that this is called from. When we raise a GraphConstructionError at the end
+ it is then caught by calling functions, where they can be responsible for
+ rewriting their own frames.
+
+ Args:
+ source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
+ AG generated code.
+
+ Raises:
+ GraphConstructionError: The rewritten underlying error.
+ Exception: The underlying error, if it could not be rewritten.
+ """
+ error_info = sys.exc_info()
+ _, original_error, e_traceback = error_info
+ assert original_error is not None
+ try:
+ _, _, _, func_name, _, _ = tf_inspect.stack()[1]
+ # The latest function call is added to the beginning of a traceback, but
+ # when rewriting the traceback of multiple function calls the previous
+ # functions' except blocks may have already rewritten their own frames so
+ # we want to copy over all of the previous frames. We may have rewritten
+ # previous frames only if the error is a GraphConstructionError.
+ if isinstance(original_error, GraphConstructionError):
+ cleaned_traceback = traceback.extract_tb(e_traceback)
+ previous_traceback = original_error.custom_traceback
+ cleaned_traceback = [cleaned_traceback[0]] + previous_traceback
+ else:
+ cleaned_traceback = traceback.extract_tb(e_traceback)
+ cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)
+
+ current_frame_indices = []
+ # This code is meant to be called from the try/except block that wraps a
+ # function body. Here we look for all frames that came from the function
+ # that this wraps, look for any matching line numbers in the source
+ # mapping, and then rewrite them if matches are found.
+ for fi, frame in enumerate(cleaned_traceback):
+ _, _, frame_func_name, _ = frame
+ if frame_func_name == func_name:
+ current_frame_indices.append(fi)
+ break
+ if current_frame_indices:
+ _rewrite_frame(source_map, cleaned_traceback, current_frame_indices)
+
+ if isinstance(original_error, GraphConstructionError):
+ original_error.custom_traceback = cleaned_traceback
+ new_error = original_error
+ else:
+ new_error = GraphConstructionError(original_error, cleaned_traceback)
+ except Exception:
+ logging.exception('Error while rewriting AutoGraph error:')
+ raise original_error
+ else:
+ raise new_error
+ finally:
+ # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
+ del e_traceback
+
+
+def rewrite_tf_runtime_error(error, source_map):
+ """Rewrites TensorFlow runtime errors raised by ops created in AG code.
+
+ Args:
+ error: errors_impl.OpError, a TensorFlow error that will have its traceback
+ rewritten.
+ source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
+ AG generated code.
+
+ Returns:
+ A TfRuntimeError with a traceback rewritten according to the given
+ source mapping.
+ """
+ # Check for cases where we leave a user method and re-enter it in the
+ # traceback. This is done by looking at the function names when the
+ # filenames are from any files the user code is in. If we find a case where
+ # we return to a user method after leaving it, then we cut out the frames in
+ # between, because we assume these in-between frames come from internal
+ # AutoGraph code that shouldn't be included.
+ #
+ # An example of this is:
+ #
+ # File "file1.py", line 57, in my_func
+ # ...
+ # File "control_flow_ops.py", line 231, in cond
+ # ...
+ # File "control_flow_ops.py", line 1039, in inner_cond
+ # ...
+ # File "file1.py", line 68, in my_func
+ # ...
+ #
+ # Where we would remove the control_flow_ops.py frames because we re-enter
+ # my_func in file1.py.
+ #
+ # The source map keys are (file_path, line_number) so get the set of all user
+ # file_paths.
+ try:
+ all_user_files = set(k.file_path for k in source_map)
+ cleaned_traceback = []
+ last_user_frame_index = None
+ last_user_user_file_path = None
+ last_user_user_fn_name = None
+ for fi, frame in enumerate(error.op.traceback):
+ frame_file_path, frame_line_number, _, _ = frame
+ src_map_key = CodeLocation(
+ file_path=frame_file_path, line_number=frame_line_number)
+ if frame_file_path in all_user_files:
+ if src_map_key in source_map:
+ original_fn_name = source_map[src_map_key].function_name
+ if (last_user_frame_index is not None and
+ last_user_user_file_path == frame_file_path):
+ if last_user_user_fn_name == original_fn_name:
+ cleaned_traceback = cleaned_traceback[:last_user_frame_index]
+ else:
+ cleaned_traceback = cleaned_traceback[:last_user_frame_index + 1]
+ last_user_user_fn_name = original_fn_name
+ else:
+ last_user_user_fn_name = None
+ last_user_frame_index = fi
+ last_user_user_file_path = frame_file_path
+ cleaned_traceback.append(frame)
+
+ for fi in range(len(cleaned_traceback)):
+ _rewrite_frame(source_map, cleaned_traceback, [fi])
+ op_name = error.op.name
+ op_message = error.message
+ rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback)
+ return rewritten_error
+ except Exception: # pylint: disable=broad-except
+ logging.exception('Error while rewriting AutoGraph error:')
+ return error
+
+
+# TODO(znado): Add arg to enable different levels of error rewriting.
+@contextlib.contextmanager
+def improved_errors(converted_function):
+ """Context manager that rewrites runtime errors.
+
+ This context manager will rewrite runtime errors so that their traceback
+ is relative to the original code before conversion.
+
+ Use with the output of to_graph, and wrap the execution of respective ops.
+ Example:
+
+ converted_my_func = ag.to_graph(my_func)
+ ops = converted_my_func(...)
+
+ with ag.improved_errors(converted_my_func):
+ sess.run(ops)
+
+ Args:
+ converted_function: Callable[..., Any], the output of a to_graph call
+
+ Yields:
+ None
+
+ Raises:
+ TfRuntimeError: if any OpError originates in the converted code, it will
+ be wrapped into a TfRuntimeError
+ ValueError: If converted_function is not generated by AutoGraph
+ """
+ if (getattr(converted_function, 'ag_source_map', None) is None or
+ not converted_function.ag_source_map):
+ raise ValueError(
+ 'converted_function must be the result of an autograph.to_graph call')
+ try:
+ yield
+ except errors_impl.OpError as e:
+ raise rewrite_tf_runtime_error(e, converted_function.ag_source_map)
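At its core, the frame rewriting above is a dictionary lookup keyed by file path and line number; a minimal sketch, with illustrative frame values and an empty source map standing in for the one built during conversion:

from tensorflow.contrib.autograph.pyct.origin_info import CodeLocation

source_map = {}  # Dict[CodeLocation, OriginInfo], built during conversion
frame = ('/tmp/generated_code.py', 12, 'converted_fn', 'x = do_work(x)')
key = CodeLocation(file_path=frame[0], line_number=frame[1])
if key in source_map:
  frame = source_map[key].as_frame()  # swap in the user's original location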
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py
new file mode 100644
index 0000000000..7be54563a1
--- /dev/null
+++ b/tensorflow/contrib/autograph/core/errors_test.py
@@ -0,0 +1,116 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for errors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.core import errors
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors as tf_errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
+def zero_div():
+ return array_ops.constant(10, dtype=dtypes.int32) // 0
+
+
+def zero_div_caller():
+ a = zero_div() + 2
+ return a
+
+
+class RuntimeErrorsTest(test.TestCase):
+
+ def setUp(self):
+ self._fake_origin = origin_info.OriginInfo('new file', 'new func', 96, 0,
+ 'print("hello world!")')
+
+ def test_error_replacement(self):
+ _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
+ src_map = {
+ errors.CodeLocation(
+ file_path=__file__, line_number=zero_div_lineno + 1):
+ self._fake_origin
+ }
+ with self.assertRaises(errors.TfRuntimeError) as cm:
+ z = zero_div_caller()
+ zero_div_caller.ag_source_map = src_map
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(z)
+ expected = cm.exception
+ current_traceback = expected.custom_traceback
+ for frame in current_traceback:
+ self.assertNotEqual('zero_div', frame[2])
+ self.assertTrue(
+ any(self._fake_origin.as_frame() == frame
+ for frame in current_traceback))
+
+ def test_error_not_found(self):
+ src_map = {
+ errors.CodeLocation(file_path=__file__, line_number=-1):
+ self._fake_origin
+ }
+ with self.assertRaises(errors.TfRuntimeError) as cm:
+ z = zero_div_caller()
+ zero_div_caller.ag_source_map = src_map
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(z)
+ expected = cm.exception
+ current_traceback = expected.custom_traceback
+ self.assertTrue(any('zero_div' in frame[2] for frame in current_traceback))
+ for frame in current_traceback:
+ self.assertNotEqual(frame, self._fake_origin.as_frame())
+
+ def test_rewriting_error(self):
+ _, zero_div_lineno = tf_inspect.getsourcelines(zero_div)
+ src_map = {
+ errors.CodeLocation(
+ file_path=__file__, line_number=zero_div_lineno + 1):
+ None
+ }
+ with self.assertRaisesRegexp(tf_errors.InvalidArgumentError,
+ 'Integer division by zero'):
+ z = zero_div_caller()
+ zero_div_caller.ag_source_map = src_map
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(z)
+
+ def test_no_ag_source_map(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'converted_function must be the result of an autograph.to_graph call'):
+ with errors.improved_errors(None):
+ pass
+
+ def test_bad_ag_source_map(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'converted_function must be the result of an autograph.to_graph call'):
+ src_map = None
+ zero_div_caller.ag_source_map = src_map
+ with errors.improved_errors(None):
+ pass
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
new file mode 100644
index 0000000000..1368ce244c
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -0,0 +1,29 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_test(
+ name = "keras_test",
+ srcs = [
+ "keras_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
index b95202c5df..a2fc7c550e 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py
+++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
@@ -1,5 +1,4 @@
-# =============================================================================
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,24 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# =============================================================================
-"""Test case base for testing proto operations."""
+# ==============================================================================
+"""Keras integration tests."""
-# Python3 preparedness imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import ctypes as ct
-import os
+import tensorflow as tf
-from tensorflow.python.platform import test
+class MinimalKeras(tf.keras.Model):
-class ProtoOpTestCase(test.TestCase):
+ def call(self, x):
+ return x * 3
- def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
- super(ProtoOpTestCase, self).__init__(methodName)
- lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
- if os.path.isfile(lib):
- ct.cdll.LoadLibrary(lib)
+
+class KerasTest(tf.test.TestCase):
+
+ def test_basic(self):
+ MinimalKeras()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb
new file mode 100644
index 0000000000..a64e266f6a
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb
@@ -0,0 +1,664 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Pa2qpEmoVOGe"
+ },
+ "outputs": [],
+ "source": [
+ "from __future__ import absolute_import\n",
+ "from __future__ import division\n",
+ "from __future__ import print_function\n",
+ "\n",
+ "import os\n",
+ "import time\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import six\n",
+ "\n",
+ "from tensorflow.contrib import autograph\n",
+ "from tensorflow.contrib.eager.python import tfe\n",
+ "from tensorflow.python.eager import context\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "YfnHJbBOBKae"
+ },
+ "outputs": [],
+ "source": [
+ "import gzip\n",
+ "import shutil\n",
+ "\n",
+ "from six.moves import urllib\n",
+ "\n",
+ "\n",
+ "def download(directory, filename):\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " zipped_filepath = filepath + '.gz'\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [784])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8)\n",
+ " label = tf.reshape(label, [])\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def mnist_train(directory):\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte')\n",
+ "\n",
+ "def mnist_test(directory):\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n",
+ "\n",
+ "def setup_mnist_data(is_training, hp, batch_size):\n",
+ " if is_training:\n",
+ " ds = mnist_train('/tmp/autograph_mnist_data')\n",
+ " ds = ds.cache()\n",
+ " ds = ds.shuffle(batch_size * 10)\n",
+ " else:\n",
+ " ds = mnist_test('/tmp/autograph_mnist_data')\n",
+ " ds = ds.cache()\n",
+ " ds = ds.repeat()\n",
+ " ds = ds.batch(batch_size)\n",
+ " return ds\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "x_MU13boiok2"
+ },
+ "outputs": [],
+ "source": [
+ "def mlp_model(input_shape):\n",
+ " model = tf.keras.Sequential((\n",
+ " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
+ " tf.keras.layers.Dense(100, activation='relu'),\n",
+ " tf.keras.layers.Dense(10, activation='softmax')))\n",
+ " model.build()\n",
+ " return model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "kfZk9EFZ5TeQ"
+ },
+ "outputs": [],
+ "source": [
+ "# Test-only parameters. Test checks successful completion not correctness. \n",
+ "burn_ins = 1\n",
+ "trials = 1\n",
+ "max_steps = 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "gWXV8WHn43iZ"
+ },
+ "outputs": [],
+ "source": [
+ "#@test {\"skip\": true} \n",
+ "burn_ins = 3\n",
+ "trials = 10\n",
+ "max_steps = 500"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "DXt4GoTxtvn2"
+ },
+ "source": [
+ "# Autograph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "W51sfbONiz_5"
+ },
+ "outputs": [],
+ "source": [
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "CsAD0ajbi9iZ"
+ },
+ "outputs": [],
+ "source": [
+ "def fit(m, x, y, opt):\n",
+ " l, accuracy = predict(m, x, y)\n",
+ " opt.minimize(l)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "RVw57HdTjPzi"
+ },
+ "outputs": [],
+ "source": [
+ "def get_next_batch(ds):\n",
+ " itr = ds.make_one_shot_iterator()\n",
+ " image, label = itr.get_next()\n",
+ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
+ " y = tf.one_hot(tf.squeeze(label), 10)\n",
+ " return x, y\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "UUI0566FjZPx"
+ },
+ "outputs": [],
+ "source": [
+ "def train(train_ds, test_ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " train_losses = []\n",
+ " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n",
+ " test_losses = []\n",
+ " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n",
+ " train_accuracies = []\n",
+ " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n",
+ " tf.float32)\n",
+ " test_accuracies = []\n",
+ " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n",
+ " tf.float32)\n",
+ " i = tf.constant(0)\n",
+ " while i \u003c hp.max_steps:\n",
+ " train_x, train_y = get_next_batch(train_ds)\n",
+ " test_x, test_y = get_next_batch(test_ds)\n",
+ " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ "\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " return (autograph.stack(train_losses), autograph.stack(test_losses), autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 789
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 11529,
+ "status": "ok",
+ "timestamp": 1531163743912,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 240
+ },
+ "id": "K1m8TwOKjdNd",
+ "outputId": "59db8f19-23a5-413a-e9d0-fb756b0e4757"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Duration: 0.592790126801\n",
+ "Duration: 0.594069957733\n",
+ "Duration: 0.591835975647\n",
+ "Duration: 0.592386007309\n",
+ "Duration: 0.595040082932\n",
+ "Duration: 0.594245910645\n",
+ "Duration: 0.624264001846\n",
+ "Duration: 0.6021900177\n",
+ "Duration: 0.592960119247\n",
+ "Duration: 0.599496841431\n",
+ "Mean duration: 0.597927904129 +/- 0.0093268291102\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd8FGX+wPHPbMum90IKvQSQ3jtSbYCAqHee9TxPT0VF\njztRT+9UzvMOsdzPUxTO3gURsYsgTRBFmvROQkJ63T7z+2OS3Wx2EwIkBC7f9+vFi+zO7Mwzz84+\n33nKPKNomqYhhBBCBGFo7gQIIYQ4d0mQEEIIUScJEkIIIeokQUIIIUSdJEgIIYSokwQJIYQQdZIg\nIYQQok4SJISow6ZNm7j44oubOxknlZWVRWZmJqqqNndSxP8gCRLilI0ZM4YePXpQXFzs9/6UKVPI\nzMwkOzsbgD//+c9kZmaybds27zpHjhwhMzPT+/raa6/lgw8+8L5+4YUXGDt2LH379mX06NHMmjUL\ngMsuu4y+ffvSt29funXrRs+ePenTpw99+/ZlwYIFAWn897//zezZs8/oOPv3789nn312Sp958cUX\nmT9/Phs3bmTUqFFntP9qtfMoGEVRGmVfQtRmau4EiPNTeno6y5cv55prrgFgz549OBwOv8JKURRi\nYmJ4+umnWbhwod/7wSxZsoRly5bx6quvkp6eTkFBAStWrADgk08+8a537bXXcvnllzN9+vQzOgZN\n0xq9cF21ahX33XcfLpdLCm7xP0FqEuK0TJkyhSVLlnhfL1myhKlTpwasN3XqVHbv3s2mTZtOus3t\n27czfPhw0tPTAYiPj2fGjBlB161vNpnVq1fzwgsv8Omnn9KnTx8uv/xyQA8u8+fP51e/+hW9e/fm\n2LFjLF68mEsuuYS+ffsyfvx43n33Xe92atcGxowZw6JFi5g8eTIDBgxg1qxZOJ1O7/LS0lIOHz5M\nt27duOWWWzhx4oS3tpOXl4emaSxYsIDx48czePBg7rnnHkpLSwFwOp388Y9/ZNCgQQwYMIAZM2ZQ\nWFjI/Pnz+fHHH3n00Ufp27cvjz322Enz8cSJE9x2220MGjSIiRMn8v7773uXbd26lenTp9OvXz+G\nDx/OP/7xj3r3D1BeXs4DDzzA8OHDGTVqFE8//bQ3/48cOcK1115L//79GTJkiLfmJ/53SE1CnJZe\nvXqxdOlSDhw4QNu2bfn888956623mD9/vt96VquVW2+9laeeeoq33nrrpNt8/PHHSUpKYtCgQXTr\n1g2D4dSvY0aMGMGtt97KkSNHePLJJ/2WLVu2jJdeeol27dqhqirx8fEsWLCA9PR0Nm3axM0330zP\nnj3p2rUrEFjr+fzzz1m0aBEWi4Wrr76aJUuWcNVVVwGwZs0aBg8ejNVq5aWXXmL27NmsXLnS+9lX\nXnmFFStW8OabbxIbG8tjjz3GX//6V+bNm8eSJUsoLy9n9erVmM1mdu7cSUhICPfccw8//fQTU6ZM\n4YorrmjQ8c+aNYsuXbrw7LPPsn//fm688UYyMjIYPHgwc+fO5frrr2fy5MnYbDb27t0LUOf+AWbP\nnk1SUhLffPMNFRUV3HrrraSmpnLllVfyzDPPMHz4cF5//XWcTifbt28/5e9LnNukJiFO25QpU/jo\no49Yu3Yt7du3JykpKeh6V155JcePH2f16tX1bm/y5Mk89NBDrF27lmuvvZahQ4cG7W84E1OnTqVD\nhw4YDAZMJhOjRo3y1lz69+/PsGHD6q31XHfddSQkJBAVFcWFF17Izp07vctWrlxZbz/Ee++9x913\n301SUhJms5nbb7+dL774AlVVMZlMFBcXc/DgQRRFoVu3boSHh5/y8R0/fpzNmzdz3333YTabyczM\nZMaMGSxduhQAk8nEkSNHKCoqIjQ0lJ49e3rfD7b/goICVq9ezZw5cwgJCSEuLo7rr7+e5cuXez+X\nlZVFbm4uFouFvn37nnKaxblNahLitE2ePJnf/OY3HDt2jClTptS5nsVi4Q9/+APPPPMM8+bNq3eb\nl112GZdddhkej4evv/6ae++9l+7duzNs2LBGSXNKSorf61WrVvH8889z6NAhVFXFbrfTpUuXOj8f\nHx/v/Ts0NJS8vDxAb/5at24d999/f52fzc7O5o477vDWjjRNw2QykZ+fz5QpU8jJyWHWrFmUlZUx\nadIkZs2ahdFoPKXjy8vLIzo6mtDQUO97qamp7NixA4C5c+fyzDPPcPHFF5ORkcHtt9/O6NGjA/Y/\nefJk7rnnHrKysnC73QwfPtybZk3TaNWqFaDXMp5++mmuuOIKYmJiuOGGG864r0icWyRIiNOWmppK\nWloa3333HXPnzq133WnTpvHyyy/z1VdfNWjbRqORiRMnsmDBAvbu3dtoQaJm85HT6eSuu+7in//8\nJ2PHjsVgMHD77bfX299Rl23btpGenk5sbGzAfqq1atWKuXPn0qdPn6DbuP3227n99tvJzs7md7/7\nHe3bt2f69Omn1AGelJRESUkJlZWVhIWFAXrtorqW17p1a2+g/uKLL5g5cyYbN27EarUG7L9du3aM\nHDmSkJAQNmzYEDQd8fHxPProowD8+OOP3HjjjQwcOJCMjIwGp1mc26S5SZyRuXPn8uqrr2K1Wutd\nz2g0cscdd/DSSy/Vuc6SJUtYtWoVFRUVaJrGqlWr2L9/v7dJ5FQkJCSQlZVVb4HvcrlwuVzExsZi\nMBhYtWoVa9euPeV9gd7UNHLkSO/r+Ph4iouLKS8v97531VVX8dRTT3mHCBcWFvLNN98AsGHDBvbs\n2YOqqoSFhWEymby1iISEBI4ePVrv/quPMyUlhT59+vDUU0/hdDrZtWsXH3zwAZMnTwbg448/9nZI\nR0ZGoigKBoOhzv0nJiYybNgw5s6dS3l5OZqmcfToUX744QdA76PJzc0FICoqCoPBcFr9SOLc1aQ1\niTlz5rBy5Uri4+NZtmyZ9/3XX3+dN998E7PZzKhRo7jvvvuaMhmikdW8oqx9xVjfVe9ll13GggUL\nKCsrC7p+REQEL7zwAgcOHMDj8ZCamsojjzwS0M7dkCvriy66iI8//phBgwaRnp7O4sWLAz4XHh7O\nAw88wF133YXL5eLCCy9k7NixdW6zvv2uWrWKv/3tb97X7du359JLL2Xs2LFomsby5cu5/vrrAbjp\nppvIy8sjPj6eiy++mLFjx5Kfn8/DDz9Mbm4u4eHhXHLJJd6C/brrruNPf/oT77zzDpMnT+aBBx6o\nN23z5s3j4YcfZsSIEURHR3PXXXcxZMgQQB/59cQTT2C320lLS2P+/PlYLJZ69/+Pf/yDf/3rX1x6\n6aVUVlaSkZHBzTffDOg1qOoAkpCQwAMPPEBaWlq93404vyhN+WS6TZs2ER4ezuzZs71BYsOGDbz4\n4ossWLAAk8lEYWEhcXFxTZUEIZpcQUEBl19++Uk75oU4HzVpvbB///5ERUX5vff222/zu9/9DpNJ\nr8RIgBDnu7Kysno7rIU4n531xsNDh
w6xadMmrrzySq699lq/KRuEOB+1bduWSy65pLmTIUSTOOuj\nmzweD6Wlpbz33nts3bqVu+++29t5J4QQ4txy1msSKSkpTJgwAYCePXtiMBgoKio66eeasOtECCFE\nHZq8JlG7cB83bhzr169nwIABHDx4ELfb7R1bXh9FUcjLKzvpei1BYmKk5EUVyQsfyQsfyQufxMTI\nM/p8kwaJe++9lw0bNlBcXMzo0aO58847mT59Ovfffz+TJk3CbDZ7JxgTQghx7mnSIbCNTa4MdHKV\n5CN54SN54SN54XOmNQm5NVIIIUSdJEgIIYSokwQJIYQQdZIgIYQQok4SJM4zTo+LVza/T7GjpFG3\nuyn3Z37I2dyo22xqdredV356jzJn+clXPgUbjv/ITye2Nuo2m1qlq5JXfnqPSldlo253bfYGtubt\naNRtNrUyZzmvbH4fu9veqNv97tg6dhTsbtRtng8kSJxndhXu4dM9K1ifffJnRp+K/+54i1d+ebtR\nt9nUtufv5NO937Ix56dG3e5rO99l4fY3GnWbTe3nvO18uvdbfjyxpdG2qWoqb+36kBe3vdpo2zwb\nfszdwqd7VrDlFIJbeXk5S5Z8UOdyl8fFu3s+4vktCwOWzZ59NxUVDb9QWbRoAe+8c/6cXxIkzjO2\nqqujEmdpo23To3q8f2eX57CnaD+bT2zD5razLf+XRttPY7N5qvLC0Xh5YXc7vH8fLctmT9F+tuRt\np9JVyY6CXY22n8bmPS8aMS8qatRKjpQeY0/Rfrbl/0K5s4KdhXsabT+N7XTyoqyslCVL3g+6TFVV\nSp2+4bQHS46wt2g/2/N3UuIo47d/voPw8IgzS/Q5TJ5Md55xePRCrDELgzKX7yro8Y1PBSy/reeN\nXJDQtdH211iqC/STBUxN0xr8dLfSGtt64oenA5bf0/c2Osa0O4VUnh32Jjgvam7rH5ueDVh+/4C7\nSY9MbbT9NRbvb+QULqReeOHfZGdncdNN19C//yCGDBnGf//7EvHxCezbt4e//vtJDr69FVeJg9+5\nryVhSAbx/fRj3/nUOhYufJ1Qzcp9982kR4/ebN++hcTEZJ54Yh4Wi6XO/e7du5t//esJHA4HaWlp\n3H//w0RERPD++++wdOliTCYTbdu245FHHmfz5h959tl5Veeywv/930t+j6ltKhIkzpCmaVS6bYSb\nw87K/uwN+AFUuioJNYX6P6rT40QDQoyBJ+zJCpYiR3GD0na28sLudLN8/WGUVjag/vS//90vfLbu\nOM/eNZKIUDMADo8TBbAEzYv6b8CqeUVZH1VTsbvthDVSXmiaRoXd7T2G2nwFY93pq3BVBnw3drcD\ng2LAYgzc7skK2XJXxcmSDVTnhYMwc8MKtPdW7OOHXScatG4wFW4LDvcoVmwzs/HrdQAMyEziyjEd\nfevUyovbbruTAwf38+LLr2I2mNi8+Ud27vyF119/j+jYBP67bgWtL++KMdSM6vKw98VNRHdLxBRq\nBgUq3TZCjVaOHTvKX//6d/70pwe4/e57eP29pfz2NzP88sKtur2vH3vsEWbN+hO9evVm4cIX+e9/\nF3DnnbN4881X+eCDZZhMJm9T1jvvvMG99/6ZCy7oid1urzf4NCZpbjpDa7K/Z/bqR065KeLjNQf5\neV/+Ke/P4XECUFpHYXao9Ah/XP0Inx9a4ff+X7//J/ev+Rvf/nSM1Vuy/ZadLEgYGnCafL8jh/9b\ns5TZqx9hX/HBk65/Jj5YuZ/l6w+zeV8OUHfBvadoHyvdr2BMPszOw75JJB9c+zgPrw8+HczJCsaa\nP/D6fHrwK/64+hGOlB1r0Pons2brcWY+s7rOc8ZRVasqDfJdbj9YwIsrVzB79SOsy97ot+yPqx9m\nbpDaI5w8YFYHppNZvO8T/rj6YXIqcimtcFJS3rDPna7qSSTUOiaT2JSzmdmrH+HH3J+976mayvGK\nHP616d/e97p1605KSgqb9+ax+VAWeeuPsvv5jex96UdcpQ6cBbaqHYLNpTdxpbRKJTw2FVXTyCqP\n5ONv/fuI3t71IV8c/pZKVyUVFeVUVJTTq1dvAC666FJ+/lkfPNKxYyceeeQBvvzyMwwG/TG2PXr0\n4tlnn+KDD96hrKz0rD0mVmoSp2HBxzvYe6yEf/5hqLcw/jF3C93jMxv0+bJKJx+t0QvSRX8ec0r7\ndtRoYlE1lez8SkLMRr79KYvIcDPuhJ0AfHLwCy5u53sUZ/VoqDdW/ozmDGNEr1Q278nj25+z6DOo\n/hExFQ0YMbNg2S9Y+3+PYoCteTvqbJL5YuMRsvMruOHizAY3AdWWX2KvSpcNrHUHuR9z9R+oOW0/\n+SU27/uV7uoaSBnRIf5TFgQrZGtqSF4AfHZIn/5+V+FeWkemN+gz9fl84xEA3vxyN+1aRREd7n8V\nWV3DLA4S5J56dwuWjpsxxsFnB74lM6IncVFW3KobVVPJsxVQ6bIFXOmXniRgNjQvvj26BoADJYd5\n+TW9M7m+8/7KMR39rvpP1YKtr7Ilfwfx1jj+NvTPActXZ30PwMpj6+iXrBfQDrd+8XWsPNt7IWa1\nWvn+lxxe/mQnDsMhyg8W0emW/hhMBvYt+gnV7evLq3BV6udihYc5C77n4RsGgKKA6vFr7lx3XH82\neHFVAK5rVqR//vMZfv75J9asWcUrr7zMG2+8z29+cwNDh45g/fo1/P73N/L008/TunWb086nhpKa\nxClwq26e2byAH078SEGpHYfT472aCjGGeNf75sh3vLj1Vb8TQFU13B6VnIpc/rX5GZRQ31Xatz8d\n4+/vr+TR7+dxsOQIa7cdx+HST8D/W7KNj1Yf8K5bXRiomkqFq5K/vPYdf1n9L778ZQuLt6zh88O+\nGsTS/Z+xL6uEQ8d9P3Zr7+8wZeyiuNzBc4u3sf1AIUeK6q/RbMr9mbkb5/sNNV26/zNe/+U9lu7/\njEXb3gJAMah6Xph8efHujs94euN/ASgstfPuin2s3nqcrPwKjpZl8/iGpzhRmee3v6Nl2Ty2YR7Z\n5Tne93KLKtm63z+dmuL25ond7aDAVsjcjfM5UnqM9cc3sSZ7g54uk4tNRWsB/5rAnLWPsmz/537b\nDFbI1rT++A/8bd1T/HI01/veB3s+5u3di/lw7zLe3Ok/QsZa47xYduALXv/lvaDbPVhymMc3PEWB\nzVfj+WTdIf65dAWPbZiH26TnfUGpg4cXbQz4fHXBVu6swKN6yK04weMbniK7PAdj4hGMcXp6C50F\nzPnoTUBvgqz2x9UP89lB/+e6nKyG+d2xdTy56Tm/zv63d33Ih3uX8fbuxby3Z6nf+jWbOt/asZS3\ndy8Out09RfuYu3F+wP73Fu3n8Q1PUWiv/9EC3tq2sxRVVckqP151nuWz4sh37CvRL9AOlBxi5VH9\nvDBZTagO/Td376qHvKPElq87DIDqKccYasZgMmDPq6DymH/aVmWt48Utr+Fw6efXL4cKMcYfxx
BV\nwKJt7/DRvk/91jcaDISHRxAVFcX8T57jw73L+OKLT+ndW3+ee25uDn369GPUlRPILTpBXkk+WVnH\naNeuPddccz2p7TKY99WzJ63tNQapSZyCnIoT7Cnah6U92PLTKat0en8gZoOvTXfxvk8Avc020hKB\nzeHmLws3UlLhoP3wneQ78jC3dePcOQiA17/cQ0i39RgqS3hr2yfsX5OJ060yuFsyP+7O48fdeVw6\npC1mk4Gf9h2HcH0/eeXFmNP2Ywgvw9JpM4rZ6ZfeLw9/y9KNIaCohA7wvW9udYgdBwu9r/fmnIB6\nmjezKo4DekHWM7G7d9t+DON8f6tGlq8/xMSBrfkuV1+v0mnn281ZvjxadYATSV9S5D7Bkn2f8vue\n13uXvbztNfLthXx5+Ftu6P4rAF77fDe7jhTx9J3Dvc0ImsFX4Jc6S1m6/3Oyyo/z6s53yanwFeIA\nOZaf0bRf+RWMAJ8fXsGkDhfhcquYjArlzvrb2bPK9bx4atkqXv7DlWiaxrfH1vitM639ZO/fRsXo\n21dV7eKarlfgdmuYTAYMVVeYL29/g2JHCV8c/oZfZ16h59F3B7D2/Rqlwo0lej/kdtKPtcL/GMBX\nw9TQKHOV89buD8muyOHdPUuwtPNv/jOm70ZVNW9hWu2Tg18wvvWFAJiMBspO0udwtFxvtnzmkzUc\nO2TmsVsGegNztRmdfHlRZvPtb22uXjhf3XlqQI3y+S2LcKluvj26hovbTCTErOfh05tfBGD98U1c\n2m48ALuPFHGi2MaInr4O9OoLKZfqZua/vyWmzyaKXAUs/Gkxx5z7/Pb1/t6lDE8dginUQljraHb/\n3wYiO8VT1qmQeHcYefl6HkR2TKR4j8bu5zcSkhBGeEaNxzIr+nmhOt1g0M/N3ccKMISXgtnDT/l6\nE1JbbaD3I9UjCufMeYTbH7oV1a1iNbbnuXlP4Ha7+dvfHqKiooKs8mzihqSyrWwnP727nq9WriUq\nwoojvpLWrbuy5siPXNppNJt2nWDBsh08evMgkmMbt09QgsQpcKo1flRmOw98thBTsn5SHMwug05Q\nXKO99VD+CY45NnAiK4SC0qorPbu+XDHbMbfZQZG9n/66qmbhcIIhooifStYRltMdU9pe3Fkd2X20\niHatoqh02akudnLKCqGqoNRUI8Eab4xJhzHGBnYCHsotxZSxG09eGnllpZjiT378Jc5SPKqHd/d8\nFLDM0vlH79/rtmdxfKf/8v15eazKWYkhIhq1PI6f9+UTYq3EEAZF9iLe3r2YSe0mEm4OI9+uB7Cs\nvEp2F+7jYMkR9pcWYmxVRG5RX28h6cHlS5ujzFerMwSPeB/t/oYDFYFDNw8cL+KJFW9zUYcRVIQ2\nrAlFMTsoLKvk2R9eC1h2z/JnMMbofwdrt88qLuDR5R8S7kjjst59GNsv3VvDyass4I1fPmRC+jhA\nQzFVfb9oGKLzUKwVaPZw5q/8kLEZo+nZQf/i7DX2U+Io9QaNEEMIwXy460t2lwQOb571/CqsbfZx\n54gplDkaNvZ/b24uqjOR//z8SsCy/+54y/v3m9/8ArT2W17hquTrI6vol9yLjMg0QC/cAbZnH+az\nH/7Db/tNZc+RUm+7h8vjYmveDoocJbz31TGchnKslssYkJkUkBc2tQJPeQWGEDiUXYkpITD9d727\nEEvCCdpc0d3v/ZFdb+azQ1/iyu6AMcRD+2t7BT3+rvcM1f8IM9P+sutRy9zsMn1B0jD/Y/3Pj29h\nSoKUC9sxrKsejDt27ESnW/rrad00jle2f8rvYy7i+edfBuD2FbMBOFh0lMjR7chQ+4ECof2/AuBY\nYREbsn7m5bVb0KLMvPnTl8wae3nQdJ4uCRKnoGYbrDltH6YkX6dkYUUFJ4pt/PmF9YRWXTA89+UK\nzK33YHEkAnowsLtcYAKD1YbBepR3dn0EhlQUo95UU+YuIaTbIQ6ocOLQLsxpZaglCRzMLtVHthh9\n7aB5FUUo1a89vivWmixtdwZ9/5eyzZhbHcQYdxzNHt6g4/9h3xGiLJGsrXW1CGCM8jUBFJTrV18H\nskuh6nlSH21ZiydxN0nx6VT8kkxZpQsUPcAeLc/maHk2HtXD5R18z4o+WnKCZ39eAICSEYrZamNP\n3lGKyvRCwK25vO2lr3+zBWua/v3UbO6q6evsL4O+P3/VYsytDvFVXgGt4qKCrlObYnHw+sZvyVMO\nBCwzxviaz9b+coztG7fSKsF3dffNgQ2YUw9gK6zgza/CGd6jlXfZnuL9wH427shHsfj6MjymckK6\n6PtS7WHsUyvZvszC87dfQojZSIXTd3dxsaOU/LJyUPxruDWtzAn+yGBH/C94Ig7z+MpFGMwulAZc\nlCoWB8bEoxy1B+ZFzZv7vOcqqve99cd/4KsjKylxljI67lKS43z9IsfdBzElw8vfL8dTmIK1p/7+\nCVs+Xx1ZqW+pVSgWq421u7oxIDOJVT9nkVdS5i3ZFIsDxVhV4/TUUdyl7CWwbgafHfsUU0o2isUB\nJleQNQIpZgemlEN+v4dqNcuL6lpczYsIU9IR8i17WHUskl9lTvf77M8F+gwAxqSuqGUx3vcLHYW8\ntns9lnagOqzsx065czwRlob9phvC+MgjjzzSaFtrYpWVwb7Ks+dAyWG25usdb4Zw/zZJm1bGVxty\nUEwuTAl6NVwJsaGYXHg0D1HGeLSoXNyR2WDw/UgqXXYcDg1jjN7e7tZcKFVVVqemH6/mslBuc7P8\nl42Ya5xoRwrz8eBGsThQ7WEYQho+DYGt3IwSVqIX1KoBxXLyvM0vUNlS9JNfM08wSkglmtuMS7Hh\njDwKQKm7BMXkIizESNu4ZE549BpO9bGCHkC//fkITqte81EMqrdgqb6iLiyAwjIbhpgTmBJ9zVfl\n7grsaiWqwYmzPAynoeFj5F12M4awcjTViN3lRDOePC80l4Vi6x4weOpdr9RVQk6ek30ncjDF630s\n+bZCPIoTNAXNGcIvBbsoVI56gyaAR1XB6PYWNqqmevNAqSqwNEcoTs3OcechNuf7RuqU2MvJrSxA\nMahY3LGUeAoanBeax4jBagO3BUxO7z7rz4sQTMlH/L7LYBSLHTx6Xlf3kezNO45mcIJqYOmqY3y+\ncxPG6Nrp1VDMToyR+lBsg6J4h996a1r2cBIS4J0f1uKO9J0XitGNIVyvpau2CO/fDaPvV3OEYbBW\nevvc6mM1hKElHORkYzKK7MWEmUM5VHrUOzKyurywGq2YFCM7Cnazq2hvwGcNFjuGCH0gSnGpSw9i\nVOWFAqnhKWSX5fHV7p8IDzWSEZ9yCsccSB46dAq+PrKKJfuWn/X9qg4ritlR74/Q6orHbm54YeAp\niccYXYBqCwODekoBpjmp9lCUEDuKUndeeMpivAVKQ1Sv7ymLwWCtQDE37KrxbNI0Agoe1R6GElJZ\nb4HkKY0NelVbF9UWjiG0Ak9REoaofG8N95yiKX4BF
cDgCkc119+Hcqp5obnMKGYX7rw0jAlZJy34\nm4OmKgHlQrgSQ4Wmn/9RxlhevmLuGe1DRjedgoYO+atLvJKBY08f7FtH0J0JeAqT/ZarDmvQzxlC\n7CgGDdVWdxXSHO7f9q1WBD6NynW8LVHGOAAUa9WxeMz1Bgj7z6Ow/zwKrUZzln3bUFxHO3tfR+YN\nDPbRevVI6EYXzzguib0ex56+eIoS/ZZrzuBNRgarDUXRSAlLqnPbhhD/7ynCEBOwjiu7Pao9tGqb\nVeurxnoDhG3zaD0vavwm7VuH48rq4H2tHAvebl0fd0GKngc7RhJybBCeEv8OIs0ZErSAMlj1AFHf\neWGw+heccdbA58m7jnX05rdiqToXFK3eAGH/eRT2LSP939s6HFe2b+iz89Cp36Xvzk/Fsacv9q3D\n6eIZj6fUP72aMyQgQADeAFH9nQYTHu3/G9GcgX1XriNd0NxmvfCtOhcUs7PeAGH/eSSGPaP93vvL\n4D8yMm2I97Vy/DTy4kS6Ny8ce/qilvs3haoOa9ALx+oA4cpqT6ZnwinvtzYJEqegooF3mNZU86Tt\nndaRSHcGwzp15IL4bqiV/gW5pyBwigPV7msUbu0aVOd+as+EqlZEB6wzoFM6mQn6j9gQot8rUF1t\nrYvmDNWX+CEEAAAgAElEQVT/ufUflKcwGc0WhafYV6ifONywdvwQzXe87aPbMHP8BC7t050h6T1R\nbf5z37jz0wI+H2uJ8/49o/OUOvdTu+ksxRqYr5rL4s2j6lFhgc0ctbisaM5QcJur0tgKzR7BbWP1\nTkjNY6TyRFx9W/AyeXzHO65rT4a36Y2zIozi7FjC8Q9qwfIiMdQXSK7ObHheVHcO16S5Qogx6kG3\nunmvZr9KMJozFM0Rhqbqpaf7RDqaPQK1TC/UFbcVQ3lyfZvwqvkbUUvjUIuT0OwRzBw/HrPH/zcS\nLC9q/kZcR+oujO2a/8WDWhl43mquENTKSL/C9+R5EUZFse+i5sKM4SSHJdI+uq2+H1s4tvyGnRf+\neRFPa2sHnrxxAjFqBqrD/2LAk19/eeHOaUduzplXf5o0SMyZM4ehQ4cyadKkgGULFy4kMzOT4uKG\nNws0txK7HiSc+3riPHBBwHJ3rm80g0kxMrHNGK7v6RtpkBgew5O3DuWGSzIZkJmEBf9eQa0yArVC\nP3E1txlrcRfu7nsLIYVdcB7O5OpBQ7ilx3WMib7K70q+pgszhhObeyFakE66CIsVax2dusE4dumj\nLtISw/l1l2l0CumD65g+DFOzReI82B37lhGgBu7LneO7ycegmXBldaBnuO/KKtri+4EO79kKzeWf\nLrU8mhBVL8Q1p4XBCcO4p9/vmdhmDFd1nkpmXCda20Zh3zGE0OIuQdPvym6HfcdgIq1Bri5VI3ER\nDe/cc+z0jSF2HuhBO1MvrulxGXfP6EnvVp0w5/TAsW243pZfy7jWo3wvPGZcWR3I0Hp73+qYlEy/\nLr6g2ybBv1allsZ5a5kxIdFc1HYsd/a+hQltLuSazCsY1aE3N3b7FTd1vhV3bkbQ9F/Udix/GjAz\n6LQsqEZaxTb8OciOX3wXK7+94BouTB+BO1uvTd13yQRMORdwRer1PH3b2IDPTmzju4lOc1kYlz6W\nXtG+mmiUJZKkmFCuv0j/Tnu39Q8KanECmks/BtVhxZXVAeeuAbiy2+Hc34P/u2kGkzMux75tGO78\nVgRzabvx3D/gbvp2DGyr1zzGOgeBBGPfMbjqLwXnvl70iBjo/b77JffCmNMdx85BaM7AVoKaNVDV\nYcV1tBOeE74y5OaL+vDAdf2Ij7bStXVsQO26Z2J3NLf+21PtYSQ5ejHAPEnPi329MCsh5BbZOFNN\nGiSmTZvGwoWBU+vm5OSwbt06UlPPvcnB6lNaNSTQU5jCr/sF/gA8RUmYNP1kGNyqP5M7XESneF9h\nGW2Jwlw1Nj40xMTYnv53lY7r0554RV/ffbwdtw28gi4pqTw44RpmjppCu1bR9Eq8gFsvGs2Q1L4B\n+7cYLUzpcAndEzuCGniiR4eFBi8k6qCWJtA5PZpHfzuIEe17cfewX9Exwfej9eRloDmCF7Ttw3wF\n99jWI7ip7xTGd/cNMYwO8QWJLq1juWxArbvV3RZaW/Uf0YiUkVzbcwrxobFM7nARI9P1YHPjsDG0\nj23NNf0D7941ahbcxzqhVcQQExaYxoGdU+nRNjHg/eAU1DLflbtakkTfiFGM7NqJnh0SUBQFa2lH\nvZYRZCBy/2RfQBiRPJJ0T18m9KiZF5H0aB/PA9f2487pPRjQwX/opOa2oBbrV/qXthvPpPYTiQ+N\nZUqHixmaqhew/VP6MDyzG+68wLu7oy1RXNZuAq0j0/1u7vPyGIkNb9jYejNW1HJfE1C/lJ5c0XlS\n1bFDZps4nvn1dYy+oAOh5sB9DWrVz/t3iqMPUztPJMnqazq8uF9nnrh1CKN66+dZp2T/gtyshGK2\n6e+5j3bGndUJzRmK+1gXopztCbEYubDtIDRbJJ4TgQEzMTSei9uOIz0ylVCTr+D22N3kbzwGqpGo\nBk6al2hNJC1Mz++iA2tw5SUwIX0CMSH6xY1BMeDJbcvR1Yvo0yrwvHDXSJ/7aGfcxzv41ajbJSRi\nrJp646qxnRjS2f/u6mtG98BToo/p7RcxkocvvoabxvfjotYTGNiqNw9e15/fXdatQcdSnyYNEv37\n9ycqKrBKN3fuXGbPnt2Uu24Sdre9KnIb6N0xAeeBC3Adb+tdrrlCvEMOnarenhlr9TUd1CwYAdKi\n/augQ7q05t5xk8iwdqBP0gW0T9XXj4uy0qO9fzt1sKvjtlGtMRtMXD6iHWN6tw1YnhQdGXQiN4BB\nyb4r5R4J3fDs0a8WQ0P8awmpCXqB2yY5kumj2gNgtRgZETfB7yqxf4cMbz+GBzeDu6eQFOYbpB5l\n8b9ybZ/oX2DfNLEXv+47ngviM5nQZQDBJMWE8vQ9o2mbFNjOnmpNp/r0DgtSWA3skkqYOXgfUBK+\nK7wL4rsywBTYnFM7Xyod+iibET1bcUWnyd6bvQCiatSaYqPM/OWGAXRKTA1Y3iEtmj6dEompdZ7E\nhUXgzm1N97jMeqd+iQg1c8flvQPe7xjTznvDWkiQIHHLZb2wBAsewAU19ucpTmRm35v58zV9uWRw\nG349rpN32dSR7bk6yFQal3e4xG9Yc3SN731YD72wbxXh++7jQ/2bSaNqTZvisBm5od/FhDvTCPf4\n1xTiIvVjsFTdfKcFuVDqUDMvatSqPTYXBRuzQDXSvU3w/q4orSpgadA9PpPf9fwNf/vtQB67eRCu\nnO/p3ymW1sn+zaa3XX4BUWEWRvdJ49J245na4TLvsh6tfenXtKo01xiOXvO8iQg1k5nu33wXbg7D\nfbwdnqIkojTftqaN7MDvJnUnIymCzDaBv41Tddbvk1ixYgWtWrWiS5fgTQTnMpvHjuYxcdnQNsRG\nhnD7qEvI
Lark43J9UrBhXdpy1LoHm60Mp0cPEgbFF4drF4yt4xOgxs2wVpOVWGsMfx76+5OmRfUE\nxvekqnbq0BATaXHRUKsp1WwwB5351J2bQaeu7diQq88rc2vPG3js500cKC7lRLF/dTU1Xj+JI8LM\nXDSoNf0zk/zu8PyialqQoZltWZZnxoXHe5ezuUaAql0QJkf6t8P3aJNCpCWC23rdVHcmVAl2TB0T\n0zhgVOjRPh6rMbBJM8RkCRowx2SMQHWGcCJ3PwC39bqRNVuP8x07SY4LI7ewsupY/PPfVhUkwqwm\nLswYDsDyg/oNT5E1xqxXXzyEmnxBPrrWeRFaK3g98KuhoBmJiTh5U2HbpFjwv6mYpDBfIRwsSMRH\nhpNlD8yLi9qOxaN62F6wC0014NzTj/bTWkMMdM7w/74mDW0bND3j24wG4KP9+rQU1hpX7waT3u6f\nFOEryOLC/M8Lk1KroHeb6ZPRgT4Zd1FYaueTdYcoLHOwdX8B7VNr9cMFCRLJfnnhO2+Of70fZ5GN\nrG9e5aeSNjAklBNrjlC8IxfNozFw6BC6TEjj+2NZZL2/i0r1GBvUL7n++pspLMynoqyIH5bPZ+/a\nGJ555j/e7V7QLp7UhHAsJiOXtBvPV199zu4X9PuMLhiXAfGgqRrH132GpaKcUKuJskojiUMyWLZk\nid904Rf/fprfsViMFq4Y2Jf3V0YzaNiZzw9Wl7MaJOx2Oy+88AKLFi3yvncejcDFoerjvKunCejd\nSb8yzv2lH9sLdnLjhT35pdDK81sW+rVDD08dxPaCXQFBIikyBqshDLtaidUYQmxIYGdzXRKjwqFW\n2RcT4vvhBmtW0DQ1aHNT/05paJr/SJZRvVI5kF1KXJR/gZWaqBd4kaFmjAZDwBQAvRN7cLj0KKHm\nEG7ocSUvbXuNETVGeQxI7sPBksN+hSRArDWWMFMolW4bEebwU5pu3BLkhrGksASeu2sgRqPiDX41\nGRVj0OASbg7HYFGgxqweQy9IoazSyYCuScz+z3oAbE7/+wc6pkWz+2gxrZN833G3uC4U2IswKAau\n73Y1r/7yDgNTfM2EvRIvIK8y3y94AqSEJWMxWnB6nMSERBMdFtrgyRCDHVPNGlywPqkQoyXoeRFu\nDvNODKmgMLBr3SPKTqZDdDtcVQHy6i7TeGf3Ynom6E1ukeEWPMUJ+n00VjOL933C5hPbAPzOS01V\nCOm1iofW1Zi7KgrUCI3YeA87Qow8tE7Pp5ThGk63G1utAVo1A2bN30ir8R2wn6jgv68uZEPuJj78\n5kMchZV0/v0ANE3jxNLjxOxPpPRoAdaYMP47T7+TvLKygrCwcN59922ee+7FoC0n1fLz83nhhX8z\n8s5LiYyIZNfrW+macgE/VmynbXQoz73yKgZF4R9rn8YaFsabz/pPF+4xaygoaGjeAQgXDWrNhX3T\nsFqarig/q0HiyJEjZGVlMWXKFDRNIzc3l+nTp/P+++8TH3/yeSESExveudbYVE3FqTrQPGHExYb5\npeXeUTd7/05K6s/ozP5+n52ZeEOd231txrzTSs8V4zP5oNaDtNITEr3piqoMbI4Kj7R4O4Nr6tI6\nkXCr7weTmBjJ1LGdiYsNo0fHBOKjfdsaGGml87pDDOuTHvT7mDPmD96/xycOYXy3IX7L/zj6ljqP\n6ZXpwaesPplWyYFV6tZJyaSn6UEzrDwwiERFhxBvCPxBJ8fGUOH01Z6qj/G6SfpAhTk3DODtL3cz\nYUg7IsJ8BetDNw/mx125XNgvw1ugPzL+bu/ySxNHcWmPGh3YwANjbq/jiCJ544pn6lhWv7TkwFE0\nHVulk5igH4fxROBFWWxsGDFB+pZaxcXjKKq6i91i4qFfDQlYp6H+fpGveXla4nim9fY1x4VHWnHu\n0X8z6VfHEFZhwWioDopGEsPi0DQoKncQGm6qsaxqDYOC2WQIeM9kMmOrNWo9M60NidF6XriPBd40\nmZYcR3hZCGX7CynfX8ie/2xE0yDaGEFFUQmhyRGc+Oogr722gFGjRtG/v55ugwHi48OJiQn8TZjN\nRmJjw8jOPsDQoUN4YsajAHxQ8QH79+/nrdue54rPruClBc8yatQo/jnlQRRF4Xfv7+Lvf3+YcePG\nMW7cOMLCwnj3qucbkNuNq8mDRM2aQufOnVm7dq339ZgxY1iyZAnR0Q27gm7Om+mqH4mIx4Tb4W7W\ntCQmRlJQEDivTqIh2Zsui0sv2N0FKWSkWjjuOEKoO5KcysB5nBw2D1FVtZBu8V282+jeOgbVGXis\nf/61fjXc3Dc3gp4X+fmBeRGlxnrTF1o1jHJoqwHsKThKvjMHg91KaVngyA9bhYd4q3612S+pV8Ax\ndkyJ5KHr+mOrcGCr8B9336NNbNC0nC2JiZEUFQQek8UZ7j2OKEUPqGMyRrA9fycnbPm4yg0Ulgam\n21WpkWLR+076J/Vusu+7ZhlRUWbjorQJXJR2ZuP7ExMjOZ5bxF0r5/i9b7BZyat6/kiCUf+eL2k7\njq926M2kFSUuCstKQdNIG92RqD567emuPr+nwFbIG7ve56a/3Ul8bgT/+Mc/GThwMDfccDOqqlFQ\nUI7LFdjE5XJ5KCqqpKSkEpvN6c3HsjI7lZVOHA6FhQvfZMOG9fz3v6+yZMnH3H//X3j88Xne6cKf\ne+7fvPHG+6f1DIkzvbhu0iBx7733smHDBoqLixk9ejR33nkn06f75iRRFOW8aW6yVwUJzW0ixNLw\nIXJnw+UdLqFnQjeSw33NAe2i22DfPhStMoJx3buSlqGQHpnK4bKjAZ83Goy0i27DnwfcRXI9N6md\nD67oNJnu8V38bhrLjOvEnwbMJD0iFafHRYG9kMSweOwnAm8iNBmMdI3rzJ/6z6RVxJlNZ9AcajZL\n/brLdDrHdiTC7Ksl9Erozp8GzCQjIo1L2o2jyF5CdEik9/w2G8zeZiGTwUSPuM7M7n8naRHBh5M2\ndppNxsYbS2My+Iq367peRYeYdn79UANS+tAqIpmMiDT6RffkD69sw2qyYnc7iOwYT/7KI4R3j8do\nMWIrrqBLdHt+1/4auqR0JrRXKKGhVj77TJ+BISwsnIqKCqKi6r7g7dbtAp599ilKS0sID4/g66+/\n4IorrqakpBiz2cyoUReSmprG3//+V8A3XXiPHr34+usvsNkqm+VZ2k0aJObNq78p5Ztvgk8ydi6q\nflANHrN39MS5IsoS6RcgqoW4Y7HjIToslIxIvRmiui8gMTQeDci3FXgL1GA3Wp1vokOi/Nqdq1U/\n+MdqCvEWeNWFZ5vIDPJtBVS4K70d6q2jmq4j8GyJsUaTGObfjKsoijcvQk2hhEbo50N1f1mH6Lbs\nLtqHhkakJRJFUWgTFfzei8aUnhiuT/rYROKsMSSE+jfFGRSDNy9S4lPo3asv119/NfFdWhE5OI5E\nezRbXtoEwHPxOfztkb/jPGHj1odvwlDVnHXfffcDMHny5dx330wSE
hL9Oq7BFwTj4xP4/e9v5847\n9YEpQ4YMZ/jwkezbt5e5c/+KpqkoisKtt97pN104aFx11TXNEiBAZoFtsOrmJs1j8nZcnyvqmvX0\n4RsHsHV/gd8wuD6JPbi6y1R6JlyAqnnYUbCLXgndg37+fBT0PoA6DEkdgEtz0z+pNw6Pg91F++kc\ne/pPRDvXBBvJVJcxGSMwGUwMbtWPMmc5B0uOkBF59u5j+ttv655NoDHU9Rup6S9/0fsKXKqb1Vnr\nGTZqEEW/LeJYWTb9U/oAkJqaxsCBgwM+O336VUyfflXQ7T777Avev8eNm8i4cRP9lnfs2IlFi94I\n+Fz1dOHNTYJEA9mqaxJuE9ZzrLnJbAj+NSbHhjG+v/8oIUVR/EYbDU8LPOHPZ6dyR7lBMTA6fRgA\nEYQzNLRhUyecL04lYBoNRu/Q3VBTaNDa2PnsVPLCbDAxJmMEACnhyaSEN2x6kf9VMndTA9WsSZxr\nzU1K0McNtUwmRa57qp1KTeJ/neTF6ZMg0UB27+gmMyHmcyvbGjqGviWQvPAxGyVgVpMgcfrOrdLu\nHFZUdVOR5gpp0htXTsWENvrso9Wdby3ZqKpmo8TQIM+nbGEGt9LH7tcc1dRS9U3qiclgqnM6GnFy\n8tChBlqw9VW25O/A9tOFvHj3BMym5mtySkyM9OaFqql+U3+0NJIXPpIXPpIXPmd6n0TLzblTdLQ0\nB8VjRvFYGnUs95lqySd/bZIXPpIXPpIXZ0ZyrwFKKuwU2AtxV4bTNiVa2r2FEC2GBIkGyC4tQDFo\naI5QxvQ9/284E0KIhpIg0QA2pz4RWKg5hCEXnH9TNQghxOmSINEAdrc+XUByTDgGaWoSQrQgEiQa\nwO7yTXgmhBAtiQSJBnC4q56sFuThNkII8b9MgkQDVNck5A5WIURLI0GiAZxVfRIWCRJCiBZGgkQD\nODzVQUKam4QQLYsEiQZweqQmIYRomSRINIDT7QYgxGQ5yZpCCPG/RYJEA1TXJELM0twkhGhZmjRI\nzJkzh6FDhzJp0iTve08++SQXX3wxU6ZM4c4776S8vLwpk3DaDpcepdJVCYDLo9ckrNInIYRoYZo0\nSEybNo2FCxf6vTd8+HCWL1/O0qVLadOmDS+++GJTJuG0fPLjLzy56Tn+svopnl+yDZeqBwmL1CSE\nEC1MkwaJ/v37ExUV5ffe0KFDMRj03fbu3ZucnJymTMJpWbJ2NwA2Stm0O88bJKzSJyGEaGGatU/i\ngw8+YOTIkc2ZhAZxVt1xHWaRmoQQomVptjGd//nPfzCbzX79FSdzpk9Yaij/ViUNZ1WfREpi7FlL\nw8mcK+k4F0he+Ehe+EheNI5mCRJLlixh1apVvPbaa6f0ubP1+FLFqPpemFzeaTls5U7ylOZ7hGq1\nmo9mbOkkL3wkL3wkL3zONFg2eZCo/Qjt7777jpdffpk33ngDi+XcbOM3GHxBQrHYcHpcmACzzAIr\nhGhhmrTUu/fee9mwYQPFxcWMHj2aO++8kxdffBGXy8VNN90EQK9evXjkkUeaMhmnTFVUb2eNEmKH\nqqAhU4ULIVqaJi315s2bF/De9OnTm3KXjULV3N4gYQipBKU6SEjHtRCiZZE7rmtRNQ0Vj/e1ElKJ\nUlWTMBuMzZUsIYRoFtJ+UovHo3qblwAUayWKUR/dJM1NQoiWRkq9Wlxuzdu8BGCMLvD+LUFCCNHS\nSHNTLW6P6m1eqs2gSHYJIVoWKfVqcXtUMOh9EjVH75qVc3O4rhBCNCUJErW4PKqvucntG800NfU3\nzZQiIYRoPhIkanG7fc1NmttXe+iakdBcSRJCiGYjQaIWt8fXca3VqElYTSHNlSQhhGg2EiRqcdUc\nAlsjSFgM0ichhGh5JEjU4qnZcV2juckiT6UTQrRAEiRqcXlUlCDNTTL8VQjREknJV4vbrXmbm6YP\n69rMqRFCiOYlQaIWd40+iQhzWDOnRgghmpcEiVqq75NQMMiIJiFEiydBohb9PgkPRsVIiFGChBCi\nZZMgUYvbo4LRjVmxYJbnRwghWjgJErW4PBqK0U2IIQRFae7UCCFE85IgUYvL7QGjG4tBmpqEEEKC\nRC02pxPFoGE1WkkJTwagT2KPZk6VEEI0jyZ9is6cOXNYuXIl8fHxLFu2DICSkhLuuecesrKySE9P\n5+mnnyYyMrIpk3FKyl2VYIBQs5UoSyT/HPFXGeUkhGixmrQmMW3aNBYuXOj33oIFCxgyZAhffPEF\ngwYN4sUXX2zKJJyyCqcNgDBzqPd/udtaCNFSNWnp179/f6Kiovze++abb5g6dSoAU6dO5euvv27K\nJJyySpcdgAhLaDOnRAghmt9Zv0QuLCwkIUF/NkNiYiJFRUVnOwn1srklSAghRLUm7ZNobImJTd93\n4dIcACTHxZ6V/Z2uczltZ5vkhY/khY/kReM460EiPj6e/Px8EhISyMvLIy4ursGfzcsra8KU6aqb\nm9x25azs73QkJkaes2k72yQvfCQvfCQvfM40WDZ5c5OmaX6vx4wZw+LFiwFYsmQJY8eObeoknBKn\nqtckQk3WZk6JEEI0vyYNEvfeey9XX301Bw8eZPTo0Xz44YfccsstrFu3jokTJ7J+/XpuueWWpkzC\nKatubgo1SpAQQogmbW6aN29e0PdfeeWVptztaXO5VTSDC4AwmSZcCCHkjuuabA43mKqChElGNwkh\nhASJGmxON0p1kDBLkBBCCAkSNdgcbjC6QFOwyrMkhBBCgkRNNrtekzArISgyT7gQQkiQqKnS4UEx\nurEoMrJJCCFAgoSfSrsLTC55bKkQQlSRIFFDudOOYlAJNUqntRBCgAQJP2WOCgDCTHKPhBBCgAQJ\nP2VOPUhEWCRICCEESJDwU+4sByA6RGaPFEIIkCDhp8Kj1yRiQ6ObOSVCCHFukCBRg60qSMSFRp1k\nTSGEaBkkSNTg0CoBiJUgIYQQgAQJPy7FBkifhBBCVJMgUYPHoD+VLtIc0cwpEUKIc0ODgsSnn35K\nebk+8ueZZ57ht7/9Ldu3b2/ShDUH1WgDjxmz0dzcSRFCiHNCg4LEf/7zHyIiIti6dStr1qzh8ssv\n57HHHmvqtJ1VTo8LzVKOySn9EUIIUa1BQcJk0h9gt3btWmbMmMGkSZNwOBxNmrCzLacyFxQwu2Oa\nOylCCHHOaFCQUBSFjz/+mOXLlzNkyBAAXC5XkybsbMsqzwEgxCNBQgghqjUoSDz44IN8/vnnzJgx\ng4yMDA4dOsSgQYPOaMevvPIKl112GZMmTeLee+/F6XSe0fbOVKGtCIAQTUY2CSFEtQYFib59+/L8\n889z/fXXA9C2bVseeuih095pbm4ur7/+OosXL2bZsmV4PB4+/fTT095eY3B53ACYDKZmTYcQQpxL\nGhQknnjiCcrKynC73fz617+md+/eLF269Ix2rKoqNpsNt9uN3W4nKSnpjLZ3plweDwBmo7FZ0yGE\nEOeSBgWJdevWERkZyZo1
a0hOTuaLL75g0aJFp73T5ORkbrzxRkaPHs3IkSOJjIxk6NChp729xiBB\nQgghAp1S28oPP/zA+PHjSU5OPqNnQJeWlvLNN9/w7bffEhkZycyZM1m2bBmTJk2q93OJiU3XX2C0\n6McTHmpt0v00lvMhjWeL5IWP5IWP5EXjaFCQiI+P58EHH2Tt2rXccsstuN1uPFVX3qdj3bp1ZGRk\nEBOjjyQaP348mzdvPmmQyMsrO+19nkx5hT6kV3VrTbqfxpCYGHnOp/FskbzwkbzwkbzwOdNg2aDm\npnnz5tGxY0fmz59PdHQ0OTk53Hjjjae909TUVLZs2YLD4UDTNL7//ns6dOhw2ttrDG5V77i2mKTj\nWgghqjWoRIyLi+M3v/kNBw8eZN++fbRt25Zp06ad9k579uzJxIkTufzyyzGZTHTr1o0rr7zytLfX\nGNyq9EkIIURtDQoS27ZtY+bMmVgsFjRNw+1289xzz9G9e/fT3vEdd9zBHXfccdqfb2zVQcJilJqE\nEEJUa1CJ+PjjjzN37lzv3dbff/89jz76KO+8806TJu5s8kiQEEKIAA3qk7DZbN4AATB48GBsNluT\nJao5eFQVkD4JIYSoqUFBIjQ0lO+//977euPGjYSGhjZZopqDt7lJgoQQQng1qEScM2cOd911FxaL\nBdAn93v22WebNGFnm0eT5iYhhKitQSViz549+fLLLzl48CCaptGuXTsmTJjAypUrmzh5Z4+3ucks\nQUIIIao1uEQ0m8107tzZ+1rTtCZJUHOprkmESHOTEEJ4nfYzrs9kWo5zkaqpaBqESE1CCCG86i0R\n9+3bV+cyt9vd6IlpTh7NA5qC2XTacVMIIf7n1BskbrnlljqXhYSENHpimpOqqaAZJEgIIUQN9QaJ\nFStWnK10NDsVVa9JGCVICCFENSkRq+g1CWluEkKImqRErKJR3dwkE/wJIUQ1CRJV9NFNUpMQQoia\npESsokmfhBBCBJASsUp1kDCZ/rfu/xBCiDMhQaKKhgoYMBokS4QQopqUiFU0RUOR7BBCCD9SKlbR\nUFE0aWoSQoiaJEhUU1SpSQghRC3NViqWlZUxc+ZMLr74Yi699FK2bNnSXEmpomFQJEgIIURNzTbl\n6eOPP86oUaN49tlncbvd2O325kqKfre1gtQkhBCilmYpFcvLy9m0aRPTp08HwGQyERER0RxJAcCj\n6eCf+HkAABL8SURBVA8cMkiQEEIIP81SKh47dozY2Fjuv/9+pk6dykMPPdSsNQmn2wUgzU1CCFGL\nojXDI+a2b9/OVVddxTvvvEOPHj14/PHHiYyMZObMmWc7KQDc/I9PKW27jAhXOot+80CzpEEIIc5F\nzdInkZKSQkpKCj169ABg4sSJvPzyyyf9XF5eWZOkJ7ewgtC24HY13T4aU2Ji5HmRzrNB8sJH8sJH\n8sInMTHyjD7fLO0rCQkJtGrVioMHDwLw/fff06FDh+ZIik7R+yQczv+t53YLIcSZarbRTQ8++CD3\n3XcfbrebjIwM/v73vzdXUlAUPTi43BIkhBCipmYLEpmZmXz44YfNtXt/VUEiKvR/65GsQghxpmQ4\nD2C16v/3bJ/YvAkRQohzjAQJwGPQh9/GhUU3c0qEEOLc0uKDhKZp3iARZWm+G/qEEOJc1OKDhKpp\nYHICEGk5s6FiQgjxv6bFBwm3R0MxOwCIkiAhhBB+WnyQ8HjUGkFCmpuEEKKmFh8kXB4NxSzNTUII\nEUyLDxJ6TcKJQTMRYrQ0d3KEEOKc0uKDhNujgtGNEXNzJ0UIIc45LT5IuDwaisGDQYKEEEIEaPFB\nwuNRweDB2HwzlAghxDmrxQcJt0cDgweTIjUJIYSorcUHCYfbhWLQMCpSkxBCiNpafJCwu/R7JEzS\nJyGEEAEkSLj1eyRMBqlJCCFEbS0+SDiqgoRZkXskhBCithYfJOwevbnJbJDmJiGEqK3FBwmnxwVI\nkBBCiGBafJDwNjdJkBBCiADNGiRUVWXq1KnceuutzZYGp6oHCZm3SQghAjVrkHjttdfo0KFDcybB\n29xkMUpNQgghamu2IJGTk8OqVauYMWNGcyUB8NUkLMaQZk2HEEKci5otSMydO5fZs2ejKEpzJQEA\npyo1CSGEqEuz3EG2cuVKEhIS6Nq1Kxs2bGjw5xITG/+hQIrRA25IiI5qku03lfMprU1N8sJH8sJH\n8qJxNEuQ+Omnn1ixYgWrVq3C4XBQUVHB7NmzefLJJ+v9XF5eWaOnpcJhB8DtUJtk+00hMTHyvElr\nU5O88JG88JG88DnTYNksQWLWrFnMmjULgI0bN7Jo0aKTBoim4lLdYACrWUY3CSFEbS3+PgmP6gbA\napIgIYQQtTX7rHYDBw5k4MCBzbZ/t+YBIFRqEkIIEaDF1yTc1TUJCRJCCBGgxQcJD3qQCLVIkBBC\niNpafJBQpblJCCHqJEECPUiEyM10QggRoMUHCY+mNzfJk+mEECJQiw8SGh5QDc0+PYgQQpyLWnyQ\nUBUPaC0+G4QQIqgWXzpqeFA0Y3MnQwghzkkSJBQVBQkSQggRTIsPEigqBgkSQggRVIsOEm6PCgZp\nbhJCiLq06CDhcqtgUDEoEiSEECKYFh0knC4PikHFKM1NQggRVIsOEjaX/nxroyI30gkhRDAtOkjY\nXfrzrY3S3CSEEEG16CBRWqk/utRskHmbhBAimPMmSLg8LrLLcwBQNRW724GqqWe0zRKbDYAQkwQJ\nIYQI5rxpjL9l8UNUqCVYK9NwhRTgMdoxYuKC+K5M7XQJiWHxp7zNUptek5AgIYQQwTVLTSInJ4fr\nrruOSy65hEmTJvHaa6+d9DMVagkA9rAsPEY7amUELruZLQXbeHLj8xTaik45HcW2cgDCLaGn/Fkh\nhGgJmqUmYTQauf/+++natSsVFRVMmzaNYcOG0aFDhzo/MzTqUoa37cH3ew6SEBnJ/7d390FR1f8e\nwN+7KynyoCIrGJKDOPhTygdMsOCiFwkMQXYn0IlxakbNMgt5SMKdUeeq6Uw4zNRtHDMrs7g5eUt/\nU/izudH4dMW1SLQGLdExWIpdEZAnZV32c//gsoayiLl4kH2//trztPs9n+Hw3u+ec76nuXEIrDc7\n8H3NEbQF/YatJ7cjL+qVe+pRNLY3AQD8ho28730iIhqMFAkJrVYLrVYLAPDy8kJoaCgsFkuvIZH1\nbDKuXGnGeH+/bvP/vSEIm//nv9Dm/xt2nP4Mhqdeg0bdebWSpa0OgGDMcK1jfXOrBaOGjcIjGg80\nW5uAR4DRwxkSREQ9UfzEtclkwvnz5zF16tS/tb121HBk/dsi2K+NRm17DTaWFuJi42WUmc9gk3Eb\n/uNkAX6trwQAXL5WhY3GbfjvC/8EALTYOn9uCvBmSBAR9UTRE9etra3IzMyEwWCAl5fX336fkLG+\nmD92IQ5W/wt1o2tR+NP2bsvfLd+JkCHTUWW+BowG/vePU8j4RxpuSCsAQOvFkCAi6
oliIWGz2ZCZ\nmYnU1FTEx8f3aRut1sfpsuUpkZhW8RgKSoogI6vQcU0LsQ4D7GoMefQSLtnPAD4e6Hr+nHXITdyw\nd4ZE6LggDBsy9H536YHqrRbuhrW4hbW4hbVwDZWIiBIfnJeXh1GjRmHt2rV93ubKlea7rtN24yaq\nLS34s74NlaZrmBg0Akfr/wWz6rfuK3YMATQ2qMUD/znvrXttvqK0Wp8+1cIdsBa3sBa3sBa33G9Y\nKtKTKCsrw9dff42wsDDodDqoVCpkZ2cjNjb2vt97+DAPTHpsFCY9NgpzpwcBAOzVk/Dlhc6QiBoT\nBaPFCGhsAABP9d//mYuIaLBTJCRmzpyJc+fOPbDPG+sV4HidEBLTGRL/z8+T5yOIiJx5aO64vh9h\nI0Mxd1w0IgMjoPX077Ys0MfPyVZEROQWIaFRa5AeluqYHj1sFK7e6LxDe8RQntwiInJG8fsklGCI\nzHG8fkTziIItISIa2NwyJP56uatCF3cRET0U3DIkACBxfBwAIHz0JIVbQkQ0cLnFOYmeJE9IwJxx\n0TwnQUTUC7ftSahVagYEEdFduG1IEBHR3TEkiIjIKYYEERE5xZAgIiKnGBJEROQUQ4KIiJxiSBAR\nkVMMCSIicoohQURETjEkiIjIKYYEERE5xZAgIiKnFAuJo0ePYv78+UhMTMTOnTuVagYREfVCkZCw\n2+3YtGkTPvzwQ3zzzTcoLi7GxYsXlWgKERH1QpGQOHv2LMaPH4+goCB4eHhgwYIFKCkpUaIpRETU\nC0VCwmw2Y+zYsY7pgIAAWCwWJZpCRES9UCQk+FxpIqKHgyKPLw0MDMQff/zhmDabzRgzZsxdt9Nq\n+SS5LqzFLazFLazFLayFayjSk3jiiSdQVVWFmpoaWK1WFBcXY968eUo0hYiIeqFIT0Kj0WDdunVY\nunQpRARpaWkIDQ1VoilERNQLlfAEAREROcE7romIyCmGBBEROcWQICIipwZ8SLjjGE8GgwFPP/00\nUlJSHPOuXbuGpUuXIjExEcuWLUNzc7Nj2ebNm5GQkIDU1FScO3dOiSb3i9raWrzwwgtISkpCSkoK\n9uzZA8A9a2G1WpGeng6dToeUlBS89957AACTyYRFixYhMTEROTk5sNlsjvWzs7ORkJCAxYsXd7vk\nfLCw2+3Q6/V45ZVXALhvLeLi4rBw4ULodDqkpaUBcPExIgNYR0eHxMfHi8lkEqvVKgsXLpTKykql\nm9XvfvjhB6moqJDk5GTHvLffflt27twpIiLvv/++FBQUiIjI4cOH5aWXXhIRkfLycklPT3/wDe4n\nFotFKioqRESkpaVFEhISpLKy0i1rISLS1tYmIiI2m03S09OlvLxcVq9eLQcPHhQRkfXr18vnn38u\nIiJFRUWyYcMGEREpLi6WrKwsRdrcnz7++GPJzc2Vl19+WUTEbWsRFxcnjY2N3ea58hgZ0D0Jdx3j\n6cknn4Svr2+3eSUlJdDr9QAAvV7vqENJSQl0Oh0AYNq0aWhubkZdXd2DbXA/0Wq1mDx5MgDAy8sL\noaGhMJvNblkLAPD09ATQ+c3YZrNBpVLBaDQiMTERQGctvvvuOwDd/14SExNRWlqqTKP7SW1tLY4c\nOYL09HTHvJMnT7plLUQEdru92zxXHiMDOiQ4xtMt9fX18Pf3B9D5z7O+vh4AYLFYEBgY6FgvICAA\nZrNZkTb2J5PJhPPnz2PatGm4evWqW9bCbrdDp9MhOjoa0dHRCA4Ohq+vL9TqzsM4MDDQsb9/rYVG\no4Gvry8aGxsVa7urbdmyBXl5eVCpVACAhoYGjBgxwi1roVKpsGzZMjz33HPYt28fALj0GFHkZrq+\nEt7CcVc91ajrwBksWltbkZmZCYPBAC8vL6f7N9hroVarceDAAbS0tGDVqlU9Dq/ftb+310JEBk0t\nDh8+DH9/f0yePBlGoxFA5/7dvs/uUAsA2Lt3ryMIli5dipCQEJceIwM6JP7uGE+D0ejRo1FXVwd/\nf39cuXIFfn5+ADq/CdTW1jrWq62tHVQ1stlsyMzMRGpqKuLj4wG4by26eHt7Y9asWThz5gyamppg\nt9uhVqu77W9XLQICAtDR0YGWlhaMGDFC4Za7xk8//YTvv/8eR44cQXt7O1pbW7FlyxY0Nze7XS2A\nzp4CAPj5+SE+Ph5nz5516TEyoH9ucucxnm5P/Li4OHz11VcAgP379zvqMG/ePBw4cAAAUF5eDl9f\nX0c3czAwGAyYOHEiXnzxRcc8d6xFfX294wqVGzduoLS0FBMnTkRUVBQOHToEoHst4uLisH//fgDA\noUOHMHv2bGUa3g9ycnJw+PBhlJSUoLCwEFFRUdi2bZtb1uL69etobW0FALS1teH48eMICwtz6TEy\n4IflOHr0KN566y3HGE8rVqxQukn9Ljc3F0ajEY2NjfD398frr7+O+Ph4rF69Gn/++SceffRRvPPO\nO46T2xs3bsSxY8fg6emJrVu3Ijw8XOE9cI2ysjIsWbIEYWFhUKlUUKlUyM7OxtSpU5GVleVWtfj1\n11+Rn58Pu90Ou92OpKQkrFy5EtXV1cjJyUFTUxMmT56MgoICeHh4wGq1Ys2aNTh37hxGjhyJwsJC\njBs3TundcLlTp07ho48+wo4dO9yyFtXV1XjttdegUqnQ0dGBlJQUrFixAo2NjS47RgZ8SBARkXIG\n9M9NRESkLIYEERE5xZAgIiKnGBJEROQUQ4KIiJxiSBARkVMMCXroLFq0CHq9HgsWLEB4eDj0ej30\nej0MBsM9v9fy5cv7NHT02rVrUV5e/neae08qKirw7bff9vvnEPUV75Ogh1ZNTQ3S0tJ6HdWza5iG\nh8W+fftQWlqKwsJCpZtCBGCAj91EdK9KS0tRUFCA6dOno6KiAqtWrUJ9fT2KioocD6HJz89HZGQk\nAGDOnDnYvXs3QkJCkJGRgRkzZuD06dOwWCxITk5GVlYWACAjIwOvvvoqYmJisGbNGnh7e+PixYsw\nm82IiIjA1q1bAXSOhZOXl4eGhgYEBwejo6MDcXFxWLx4cbd21tXVITc3Fw0NDQCAmJgYLF++HNu3\nb0dbWxv0ej2ioqKQn5+P06dPo7CwENevXwcAZGZmIjY2FlVVVcjIyEBycjLKyspgtVqxYcMGRERE\nPJBak5u4n4ddECnJZDLJ7Nmzu807ceKETJkyRX7++WfHvL8+kKWyslLmzp3rmI6NjZVLly6JiMjz\nzz8vubm5IiLS1NQkkZGRYjKZHMuOHTsmIiJvvPGGLFmyRG7evCnt7e0yf/58MRqNIiKycuVK+eCD\nD0REpLq6WmbMmCF79+69o+27du2S9evXO6abmppEROSLL76QnJycbm3X6XRy9epVERGpra2V2NhY\naWlpkd9//10mTZokxcXFjn2fO3eu2Gy2vheR6C7Yk6BBZ8KECXj88ccd05cvX8a7774Li8UCjUYD\ni8WCxsZGjBw58o5tn332WQCAj48PQkJC
UFVVhaCgoDvWe+aZZzBkSOfhM2XKFFRVVSEyMhJGoxGb\nN28GAIwbN87RY7nd9OnT8dlnn2Hbtm2YNWsWYmJielyvrKwMJpMJy5Ytcwz6qNFoUF1djeHDh8PT\n0xNJSUkAgKeeegoajQaXL19GaGhoX8tF1CuGBA06Xl5e3aazs7OxYcMGzJkzB3a7HVOnTkV7e3uP\n2w4dOtTxWq1Wo6Oj457W6+tzCmbOnIn9+/fjxIkT+PLLL7Fr1y58+umnd6wnIggPD8fu3bvvWFZV\nVXXHPLvdPqielUDKe3jO6BH1QPpw3UVLS4tj1M+9e/c6/cfvCpGRkY4hmmtqanDq1Kke1zOZTPD2\n9kZSUhLy8/Pxyy+/AOh8VsRfH1ofERGByspK/Pjjj455Z8+edby+fv06Dh48CKDz8Z0AMH78eNfu\nFLk19iToodaXb80GgwErVqzA2LFjERUVBR8fnx63v/29nC3rbb1169bhzTffRHFxMSZMmICIiIhu\nn9eltLQUe/bsgUajgYhg06ZNAIDo6Gh88skn0Ol0mD17NvLz87F9+3YUFBSgubkZN2/eRHBwMHbs\n2AEA8Pf3x4ULF5Ceng6r1YrCwkJoNJq71oSor3gJLJELtbe3w8PDA2q1GmazGenp6SgqKkJwcLDL\nP6vr6qbjx4+7/L2JurAnQeRCly5dwtq1ayEisNvtyM7O7peAIHpQ2JMgIiKneOKaiIicYkgQEZFT\nDAkiInKKIUFERE4xJIiIyCmGBBEROfV/smX5vm0Z6kkAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f970d490590\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test_accuracy 0.1\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXl4FFW6/79V1Vv2BEhIAG/AuCAIsgjoCEFgFDSsio7I\n6Dg4l/GODgpu4wxcnWHEHYXBDQVllJ/LRUAhDKCgYd+XsCVhS0IWOnvSSXqtqt8f1V1d1V2d7iwN\nIbyf5+Ghu6q66tRJ1fmedznnMKIoiiAIgiCIILCXuwAEQRDElQEJBkEQBBESJBgEQRBESJBgEARB\nECFBgkEQBEGEBAkGQRAEERIkGARBEERIkGAQVx0HDhzAPffcc7mL0eEZOHAgioqKLncxiDaEBIOQ\nGT16NPr164eamhrV9kmTJqF3794oKSkBAPzlL39B7969cezYMfmYwsJC9O7dW/7+yCOPYNWqVfL3\njz76CGPGjMGgQYNw5513Ys6cOQCA8ePHY9CgQRg0aBD69OmD/v37Y+DAgRg0aBCWLl3qV8YlS5bg\nhRdeaNV93nrrrfjPf/7TrN98/PHHePfdd7Fv3z6MHDmyVdf34FtHHY3Dhw+jR48el7sYRBuiu9wF\nINoXPXr0QGZmJqZPnw4AyMvLg91uB8Mw8jEMwyA+Ph7vvfceli1bptquxZo1a7Bu3TqsWLECPXr0\nQGVlJbZu3QoAWL9+vXzcI488gsmTJ+P+++9v1T2IohiwLC0lKysLzz33HJxOZ5ufu73C8zw4jrvc\nxSDaEWRhEComTZqENWvWyN/XrFmDKVOm+B03ZcoU5Obm4sCBA0HPefz4cQwfPlzubXbu3BkPPPCA\n5rFNzVSzfft2fPTRR9iwYQMGDhyIyZMnA5CE5t1338W0adMwYMAAFBUVYfXq1bj33nsxaNAg3HXX\nXfjmm2/k8/haCaNHj8by5csxceJEDBkyBHPmzIHD4ZD319XVoaCgAH369MHMmTNRVlYmW0Hl5eUQ\nRRFLly7FXXfdhdtuuw2zZ89GXV0dAMDhcOD555/HsGHDMGTIEDzwwAOoqqrCu+++i4MHD2L+/PkY\nNGgQ/vnPf2re89NPP43hw4djyJAheOSRR3DmzBl5n91ux+uvv47Ro0djyJAhmD59ulzuAwcO4KGH\nHsKQIUMwatQorF27Vq4rpVWzZs0aPPzww/L33r17Y+XKlRg7dizGjh0LAHj11Vdx5513YvDgwbj/\n/vtVf3NBEPDRRx/hrrvukvebzWb5XBcuXJDr4Y033sCoUaMwfPhwvPLKK3JZq6ur8cQTT2DIkCEY\nNmwYfvvb3wZ8BojLCwkGoeKWW25BQ0MDzp07B0EQsHHjRkycONGvITeZTHjiiSewcOHCkM65du1a\nLFu2DMePH4cgCC0q24gRI/DEE0/g3nvvxeHDh+VGEADWrVuHf/7znzh06BBSUlLQuXNnLF26FIcO\nHcJrr72G1157DadOnZKP97USNm7ciOXLl2PLli3IyclRieaOHTtw2223wWQy4ZNPPkFSUhIOHz6M\nQ4cOITExEStWrMDWrVuxcuVKbN++HbGxsfj73/8OQGqQ6+vrsX37duzbtw9///vfYTQaMXv2bAwe\nPBjz5s3DoUOHMHfuXM17HjlyJH788Ufs2rULffr0wXPPPSfve/3113Hy5El888032LdvH55//nkw\nDIPS0lLMnDkTjz76KPbs2YO1a9eq3IW++NbF1q1bsWrVKmzYsAEA0L9/f/zwww/Yv38/JkyYgGee\neUZu7JcvX44NGzbg008/xcGDB7FgwQKYTCa/87711lsoKCjADz/8gM2bN8NsNuP9998HAHz22WdI\nTk7G3r17sWvXLsyePTtgWYnLCwkG4cekSZOwdu1a7Ny5E9deey2SkpI0j3vwwQdRWlqK7du3N3m+\niRMnYt68edi5cyceeeQR/OpXv9KMT7SGKVOmIC0tDSzLQqfTYeTIkbJFc+utt+KOO+5o0hp69NFH\n0aVLF8TGxmLUqFEqcfnll1+ajFt8++23eOaZZ5CUlAS9Xo8nn3wSmzZtgiAI0Ol0qKmpwfnz58Ew\nDPr06YOoqKiQ7+u+++5DRESEfN6cnBzU19dDFEWsXr0ac+fORWJiIhiGwYABA6DX67Fu3Trccccd\nuPfee8FxHOLi4poUDF/++Mc/IiYmBgaDAQAwYcIExMbGgmVZPPbYY3A4HDh//jwAYNWqVZg9ezZS\nU1MBADfeeCPi4uIAqK3FVatW4aWXXkJMTAwiIyMxc+ZM2R2p0+lQXl6OoqIicByHwYMHh1xW4tJC\nMQzCj4kTJ+K3v/0tioqKMGnSpIDHGQwG/OlPf8KiRYvwzjvvNHnO8ePHY/z48eB5Hj/99BOeffZZ\n9O3bF3fccUeblDk5OVn1PSsrCx988AHy8/MhCAJsNhtuvPHGgL/v3Lmz/DkiIgLl5eUApEZv165d\neOmllwL+tqSkBE899RRYlpV/o9PpUFFRgUmTJuHixYuYM2cOLBYLJkyYgDlz5oQUGxAEAQsXLsSm\nTZtQXV0NhmHAMAyqq6vhcDjgcDhwzTXX+P2utLRUc3uo+Nbl8uXLsWrVKrlOGhoaUF1dDQC4ePFi\n0GtVVVXBarWqYlOCIMiC8vjjj2PJkiWYMWMGGIbBAw88gJkzZ7a4/ET4IAuD8KNbt27o3r07tm3b\nhrvvvrvJY++77z5YLBb8+OOPIZ2b4ziMHTsWN954I06fPt0WxQWgdn84HA48/fTT+MMf/oDdu3dj\n//79SE9PbzI+Eohjx46hR48eSEhI8LuOh5SUFHzyySfYt28f9u3bh/379+PIkSNISkqCTqfDk08+\niczMTHz99df45ZdfZFdasOD5unXr8PPPP2PFihU4cOAAtm7dKt9DQkICjEYjCgsLNcujtR0AIiMj\nYbPZ5O8eEVCiLNeBAwfw6aefYvHixdi/fz/279+P6OhouRzJyckBr+UhISEBERERWL9+vVxHBw4c\nwMGDBwEAUVFRePHFF/HTTz/ho48+wueff449e/Y0eU7i8kCCQWiyYMECrFixQvZHB4LjODz11FP4\n5JNPAh6zZs0aZGVloaGhAaIoIisrC2fPnkX//v2bXa4uXbqguLi4ycbf6XTC6XQiISEBLMsiKysL\nO3fubPa1AMkdlZ6eLn/v3LkzampqUF9fL2/7zW9+g4ULF8ppx1VVVdiyZQsAYO/evcjLy4MgCIiM\njIROp5Otiy5dushBYS0aGhpgMBgQGxuLxsZGvPPOO3JjzjAM7rvvPrz++usoKyuDIAg4cuQInE4n\nJkyYgN27d2Pjxo3geR41NTXIyckBIAWiN2/eDJvNhoKCAnz33XdN3n9DQwN0Oh3i4+PhcDiwZMkS\nNDQ0yPsfeOABLFq0CAUFBQCA3Nxc1NbWqs7hsRoWLFiAqqoqAIDZbMaOHTvkOvaITmRkJDiOo+ys\ndkrYBWPbtm0YN24cxo4dG9BvvWHDB
mRkZGDChAmqoB5xaVH2LK+55hr07dtXc58v48ePR1JSkl/q\nrYfo6Gh89NFHcjbPO++8g1deeQWDBg0KeP1AjBs3DqIoYtiwYbjvvvs0fxcVFYW//e1vePrppzF0\n6FBs2LABY8aMCXjOpq6blZWlil9ce+21yMjIwJgxYzB06FCUl5fjd7/7HcaMGYMZM2Zg8ODBeOih\nh5CdnQ0AqKiowKxZszB48GCMHz8ew4YNw8SJEwFIcZONGzdi2LBhePXVV/2uPXnyZKSkpCA9PR3j\nx4/HwIEDVftffPFF3HDDDZg6dSqGDRuGd955B6IoIiUlBUuXLsXy5csxdOhQTJkyRRaMxx57DHq9\nHnfccQdeeuklTJgwocm6GDFiBEaMGIGxY8dizJgxiIiIULmsfv/73+Oee+6R733u3LmyBaM813PP\nPYfU1FQ8+OCDuPXWWzFjxgzk5+cDAPLz8/HYY49h4MCBmDZtGqZPn44hQ4YE/JsQlw8mnCvuCYKA\nsWPH4vPPP0dSUhKmTp2KhQsXIi0tTT6moKAAs2fPxr///W9ER0ejqqoKnTp1CleRCCJkKisrMXny\n5KBBfYK4WgirhZGdnY3U1FR0794der0eGRkZsqnu4dtvv8XDDz+M6OhoACCxINoNFoulyWA3QVxt\nhDVLymw2IyUlRf7etWtX1XQSAGSzdNq0aRBFEU8++SRGjBgRzmIRREj07NkTPXv2vNzFIIh2Q1gF\nIxRvF8/zKCwsxMqVK1FSUoLp06cjMzNTtjgIgiCI9kFYXVLJycly5gggWRy+g8C6du2KMWPGgGVZ\n9OjRA7169ZKtjkCEMexCEARBBCCsFka/fv1QWFiI4uJiJCYmIjMz028qiV//+tfIzMzE5MmTUVVV\nhYKCgqADgRiGQXm5JZxFv2JITIyhunBDdeGF6sIL1YWXxMSYVv0+rILBcRzmzZuHGTNmQBRFTJ06\nFWlpaVi8eDH69euHUaNGYcSIEdi5cycyMjLAcRxeeOEFeWoBgiAIov0Q1rTacEI9BgnqPXmhuvBC\ndeGF6sJLay0MGulNEARBhAQJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBEGEBAkGQRAEERIkGARB\nEERIkGAQBEEQIUGCQRAEQYQECQZBEAQREiQYBEEQREiQYBAEQRAhQYJBEARBhAQJBoA6hwX/PvkN\nNpz/8XIXhSAIot0S1vUwrhROVuZi78WDAIDR16TDpDNe5hIRBEG0P8jCAMCLvPxZEIXLWBKCIIj2\nCwkGAEGxhhQJBkEQhDYkGFCLhAASDIIgCC1IMOAjGGRhEARBaEKCAUBUiAQvkGAQBEFocVUKRo29\nFg7eKX/nFYIhkkuKIAhCk6tOMGwuO/6281W8tv9deZuoCHrz5JIiCILQ5OoTDN4GAChrrJC3KQPd\nIgkGQRCEJledYDAat6y0KsjCIAiC0KZDC0aj04ptRbtU8QqtGIWoypIS/fYTBEEQHXxqkK9yv8Oh\nsmzUOiyYcO1YANpZUOqBe7zffoIgCKKDWxhFlhIAQLkiXqFlYQhkYRAEQQSlQwuGRwhYxnubWjEK\nGrhHEAQRnLALxrZt2zBu3DiMHTsWS5cu9du/Zs0a3H777ZgyZQqmTJmCVatWtdm1BUjWglIwtARB\nmSVFgkEQBKFNWGMYgiBg/vz5+Pzzz5GUlISpU6dizJgxSEtLUx2XkZGBuXPntvn1PeMrGDDeMmla\nGDT5IEEQRDDCamFkZ2cjNTUV3bt3h16vR0ZGBrZs2eJ3nBimuIHXJeUVDK2pzMklRRAEEZywCobZ\nbEZKSor8vWvXrigrK/M7bvPmzZg0aRKefvppXLx4sc2u73E1MQFcUk7B5beNpywpgiAITcIqGKFY\nDqNHj8bWrVvx/fff4/bbb8eLL77Y5tcPFMNw8A73cd5tDXbvmA2CIAjCS1hjGMnJySgpKZG/m81m\nJCUlqY6Ji4uTPz/44IN4++23Qzp3YmJM8IPcnqjICIN8fAVM8u7YeCO6RMVAf46Tt8XEGEM7dzvi\nSitvOKG68EJ14YXqom0Iq2D069cPhYWFKC4uRmJiIjIzM7Fw4ULVMeXl5UhMTAQAbNmyBdddd11I\n5y4vtwQ9hhck95Ld5pKPr6qul/dfLK+G2KhHo9Uhb6uubQjp3O2FxMSYK6q84YTqwgvVhReqCy+t\nFc6wCgbHcZg3bx5mzJgBURQxdepUpKWlYfHixejXrx9GjRqFL774Alu3boVOp0NcXBxee+21Nrt+\nsHEYDncMQzmYT6D1MAiCIDQJ+9Qg6enpSE9PV22bNWuW/HnOnDmYM2dOWK7tSZdlGO20WqcgWRZK\nEXEJFPQmCILQokOP9BY9A/cQKOgtBbhpxT2CIIjgdFjBqKi1yu4lRjUOQ2lhSIIhqBZQIguDIAhC\niw4rGC98uFsWBzbASG+tcRgusjAIgiA06bCCAQAKw0JGaxyGQC4pgiCIoHRowfDgmYQQ8LUmtEZ6\nk2AQBEFocXUIRgBB8GREKbdRWi1BEIQ2V51gKFfUcwpO/FSYhdM1Z+VtFPQmCILQpkMu0eobhxAD\nTF/uElxYf35zk78lCIIgJDqkhWF3qK0E5QJJvEYMQwnFMAiCILTpkIJh8xGMBpt3riilteHUEAyK\nYRAEQWhzVQjGmeIa+TOvimFoWRgUwyAIgtDiqhCMQEuwarukwrP6H0EQxJVOBxUMtRAEWoJV2yVF\nFgZBEIQWHVQw1I0+H3Achv/qehT0JgiC0KaDCoYLUIzuFgNYGC5Rw8IgwSAIgtCkQwqG3cEDjFIk\ntGMYTp7SagmCIEKlQwqGwyUAjLZIBLMwRBIMgiAITTrkSG8XLwCst+G3O1348PtjiE7LwxnLGXm7\ndlotCQZBEIQWHVIweF5UuaTAiDhwthCmmL2q47TSaimGQRAEoU2HdEm5BAEMq8iU0lgXQzqOBIMg\nCCJUOqZg8KLKJRVl4jSP0xyHQQP3CIIgNOmYguESAIWFkdQpAso0WwBgwcpreitRTlRIEARBeOmY\ngiGIYBQWBsOIqqwpAICokxdQUkIuKYIgCG06pmDwagtDSzAYgdMc6U2CQRAEoU0HFQwX2Ngq+ftF\n60UweofqGJHnNGMYNA6DIAhCmw4pGMW6Q9CnnJe/23g7jDftUx0j8tqBcAEU9CYIgtCiQwqGhSsO\negzv1B6CQhYGQRCENh1SMBhRH/QY0aUtGJQlRRAEoU2HFAwI2u4mJSIfQDDIwiAIgtAk7IKxbds2\njBs3DmPHjsXSpUsDHrdx40b07t0bJ06caPU1GSGEGU98BCMlojtEgYFIMQyCIAhNwioYgiBg/vz5\nWLZsGdavX4/MzEycPXvW77iGhgZ8+eWXGDBgQBtdOLhgKC2MCFs3PNzzdwAYimEQBEEEIKyCkZ2d\njdTUVHTv3h16vR4ZGRnYsmWL33GLFi3Cf//3f0OvDx57CIkA7iYPDBjVMRwM0LEcIDKUJUUQBB
GA\nsAqG2WxGSkqK/L1r164oKytTHXPq1ClcvHgRI0eObLPrBkqZ9aBnjBAVcQ6O4cCyDCAyECnoTRAE\noUlYpzcXg0zkJ4oiFixYgDfeeCPk33hITIwJfN4gv43QG1GvEBW9jkOXztEAGDBM0+duj1xp5Q0n\nVBdeqC68UF20DWEVjOTkZJSUlMjfzWYzkpKS5O8NDQ04c+YMHnnkEYiiiIqKCvzpT3/Chx9+iL59\n+zZ57vJyS8B9vMYcUUoMPhaGKDCorWkERAa8wDd57vZGYmLMFVXecEJ14YXqwgvVhZfWCmdYBaNf\nv34oLCxEcXExEhMTkZmZiYULF8r7o6OjsXv3bvn7I488gpdeegl9+vRp1XWDTVEeoTepUm85hiWX\nFEEQRBDCKhgcx2HevHmYMWMGRFHE1KlTkZaWhsWLF6Nfv34YNWqU6niGYUJ2STVFsLEUPWJScFpp\nhIgsWIaBKFJaLUEQRCDCvkRreno60tPTVdtmzZqleey///3vNrmmCAEMgFE9huPnoh2qfcmRSXjg\nhonY+mOm4geMZGGALAyCIIhAdMiR3h6X1K+6DfXbd0+vX8PA6QFVJhUDTnZJkYVBEAShRYcTDFEU\nZSuBZfxvz7NNFfT2WBgkGARBEAHpcILh4r0NPsswfvs5t2BE6k3ejQIjHSsyCJ6USxAEcXXSAQVD\nkFfXa8rCeO0Pw+Vtouh2SVEMgyAIIiAdTjB4wbscK6NxeywjuaJiIg3ejeSSIgiCCEqHEwwXL8Dj\nVmrKJaVEEL1Bb9+1vwmCIAiJDicYgsLCYBkWg5NuUe3XEozrr4kiC4MgCCIIHU4weEEEoxCMGTdP\nx8ged8j7PS4pADByklsqMtojEgxAMQyCIAhNOpxgCILXQmDcLikd6xUJZSA8UhcJAGhwNkrHgwEY\nWnWPIAhCi6CCYTabL0U52gxl0Jt1356eUax9oRCMLhGdAACe2Ug8QfIXtr8SdAJDgiCIq42ggnH/\n/ffjz3/+s2qSwPaMSjDcFgYXwMJ45KbfYHDSLZiUNs69RTre6rLB4qy/NAUmCIK4QggqGFu3bsWY\nMWPw3nvv4d5778XKlStRX99+G1PJJaUeh6FjvRaGUjA6RyRgxs3TkWCKB6BOw9Uaw0EQBHE1E7RV\nNBgMmDx5Mr755hv885//xCeffIL09HTMnz8flZWVl6KMzUI1DkNDMJTWhi8svGm4bTFrLkEQREci\npG50cXEx3nnnHTz77LO4/fbb8emnn6Jz5854/PHHw12+ZqNKq3ULgC5ADMMfr2BQ4JsgCEJN0OnN\nn3jiCeTl5eGhhx7C6tWrkZCQAAAYNGgQNmzYEPYCNhdeUA7ck8Qh2hAl72/K1cSChSfUzZNgEARB\nqAgqGJMmTcLdd98NjvN35axfvz4shWoNXguDkdNqY/TR8v6mBINR7BNEypIiCIJQEtQlFRcXh8bG\nRvl7XV1du86Y4kVp4B6jcC/FGryCwTGBYxgMuaQIgiACElQw3nzzTURHexvc6OhovPnmm2EtVGvg\nec/Eg97GP8bgXfi8qRiG0voglxRBEISaoIIhiqLs2gEAlmXB8+3XXeNxSSlTZCN03rUvmnRJKUSG\nJ5cUQRCEiqCCERUVhaNHj8rfjx49isjIyLAWqjXw7nEYysZfJXhNuKRYKGMYZGEQBEEoCRr0fv75\n5/Hkk0/iuuuuAwCcOXMGS5YsCXvBWoogui0MjanNAe0pzz2og94kGARBEEqCCsbAgQORmZmJI0eO\nQBRFDBw4EHFxcZeibC3CM3CP9TGeJl47Dudq85sMeivFhOaSIgiCUBNUMAApU2rkyJHhLkubIGi4\npABgbM/RQX+rtEoo6E0QBKEmaAwjJycHv/nNb3DLLbfgpptukv+1VyQLAwFdUk2iWG3vqy25KC5v\nv3NmEQRBXGqCCsYrr7yCZ555BqmpqcjKysLMmTMxe/bsS1G2FqGVJRUqjEIwiivqsXxDTlsWjSAI\n4oomaKvqcDhw++23QxRFJCUlYfbs2di+ffulKFuL8GRJNRXcDoTqJ4wIp4vcUgRBEB6CCgbLSofE\nxcUhJycH1dXVKC4uDnvBWopnidaWWBhKlxQY0T0vFUEQBAGEEPTOyMhAdXU1Zs6ciWnTpkEQBMya\nNetSlK1FCK2yMETVZ8+ocYIgCCKIYAiCgNtvvx0JCQlIT0/Hvn37YLfbVVOFBGPbtm1YsGABRFHE\n/fffj5kzZ6r2f/3111i5ciU4jkNUVBT+8Y9/IC0trWV3A/dstYzYsgWQfFxSZGEQBEF4abJVZVkW\nf/vb3+Tver2+WWIhCALmz5+PZcuWYf369cjMzMTZs2dVx0yYMAHr1q3D2rVr8fjjj+O1115r5i34\nXrMVQW8oXVICXAJZGARBEB6CtqppaWkoKipq0cmzs7ORmpqK7t27Q6/XIyMjA1u2bFEdExXlXaui\nsbFRjpm0FHngXossDJ8YBrmkCIIgZILGMKqqqjBx4kQMHjxYNYfUokWLgp7cbDYjJSVF/t61a1cc\nO3bM77iVK1fi888/h8vlwooVK0ItuyaC6Fk8qXXjMMCI4LucxvGKLri5S/sdd0IQBHGpCCnonZGR\n0aKTh7ou9vTp0zF9+nRkZmbigw8+wOuvvx70N4mJMZrbjUY94BCh1+kCHhMInY4BXNJnhnMBKXn4\nMDsH3/7mw2ad51LT3PvsyFBdeKG68EJ10TYEFYwpU6a0+OTJyckoKSmRv5vNZiQlJQU8/t5778XL\nL78c0rnLyy2a2y31dgAiRCHwMYHgeUWQW+cIeq32QGJiTLsu36WE6sIL1YUXqgsvrRXOoIIxa9Ys\nzWk2QnFJ9evXD4WFhSguLkZiYiIyMzOxcOFC1TEFBQVITU0FAPz888/o2bNniEXXRhBEgG1ZWq3I\niJ7lwMHonK0qB0EQREcjqGCMGjVK/my327Fp06aQ0145jsO8efMwY8YMiKKIqVOnIi0tDYsXL0a/\nfv0watQofPnll9i9ezf0ej1iY2PxxhtvtPxuALhEAQzT9EJJAcurmMmWBIMgCEJNs11S9913H/7n\nf/4n5Aukp6cjPT1dtU058E+ZttsWCO6xEy0RjP6mdGyuKwTD8SqXFEEQBBFCWq0vDMO0OM32UuCS\nBaP5LqkYXRwceYMBAAwJBkEQhIpmxTBEUURubi5uv/32sBespbTGwuBYBhCleyWXFEEQhJpmxTA4\njsOMGTMwYMCAsBaqNXim8+BaKBiiRzD0JBgEQRBKwppWezngW2Fh6DhWtjAIgiAINUFb1WnTpqG2\ntlb+XlNTg+nTp4e1UK3B5V5atSVTjBgNHCD6/47W9yYIgghBMBobGxEXFyd/j4+PR319+126VGiF\nS8pk4DQtDJdIgkEQBBG0VRUEAY2NjfL3hoYG8Hz7bUB5d+POsVyQI/0xGXTagiG4Wl0ugiCIK52g\nMYzx48djxowZmDZtGgDgq6++wsSJE8NesJbiEqV0W
BNnaPZvA1kYToEC4ARBEEEF449//COSkpKw\ndetWiKKIhx56CJMnT74UZWsRLlFq3I06Y7N/azJwUK+i5D4nWRgEQRDBBQOQMqWulGwpwT3dbEst\nDJFcUgRBEJoEjWH8+c9/Rk1Njfy9uroaTz/9dFgL1RpccFsYXPMtjEBZUk4SDIIgiOCCceHCBcTH\nx8vfExISUFhYGNZCtQaXKDXuhhZYGByrPQ7DwVMMgyAIIqhg8DyvyopyOp1wONrvPEs8pLIZWyAY\nADQFw+4kwSAIgggawxg+fDhmz56NRx99FACwYsUKv9ln2xOeGEZLLAwAmoJhdbZfgSQIgrhUBBWM\nOXPm4OOPP5aXTR01ahSGDRsW9oK1FF6OYbRUMPyNLhIMgiCIEFxSer0eTz31FN5//33cdddd+OGH\nH/DXv/71UpStRQiMZGG0JOgtwUAU1NXicJFLiiAIokkLw+VyYevWrfjuu+9w5MgRuFwuLFu2rF3P\nVutxSbXYwgAAgQNY7/reNhIMgiCIwBbGa6+9hjvvvBNff/01xo8fj6ysLMTFxbVrsQCUFkbLBOOv\nvx2MCL36txabDau3nUNtA7mmCIK4egloYXz11VcYOHAgZs6cidtuuw0A5IWU2jMi4wKDlge9r+sR\nh+hCI2yoDujAAAAgAElEQVS2BnnbgbwSVJ5lcMFswdMP3NJGJSUIgriyCCgYO3bswLp16/Dmm2+i\ntrYWkydPbteTDnoQWUkwWh7DUE6NzgAQYbHbAACVdfZWl48gCOJKJaBLKjY2FtOnT8fq1avx/vvv\no7a2FjabDdOnT8fXX399KcvYLES29TEMxl0tBkYSHY+b6wowsAiCIMJGSItG9O7dG3PnzsX27dsx\nffp0bNmyJdzlahGCKIJheUBkWjS9uQfWrQwGxgQA0jmhNS0hQRDE1UNIkw960Ov1uPfee3HvvfeG\nqzytgudFgBHAiC0XC8C7vKue1QMCAM49lxQpBkEQVzHNX5auHePiBYAVwKJ1gsG4lUHHSHoqWxjk\nkyII4iqmQwkGL7gtjFYKhsclpWP07g28e3urTksQBHFF0+EEg2EFsK10SXmC3hzLQRQBcLy8hyAI\n4mqlYwmGxyXFtI2FAYiAwIFhKUuKIAiiQwmGy+2SanUMg/FUiwjwOtnCIMEgCOJqJuyCsW3bNowb\nNw5jx47F0qVL/fZ//vnnyMjIwKRJk/D73/8epaWlLb6WZGHwbRb0FhkRosAp0mpJMQiCuHoJq2AI\ngoD58+dj2bJlWL9+PTIzM3H27FnVMX369MHq1avx/fff4+6778abb77Z4us5XTwYBuCaly3sh8ol\nxXNyWi1ZGARBXM2EVTCys7ORmpqK7t27Q6/XIyMjw2/Q39ChQ2E0SiOqBwwYALPZ3OLr2d1LqbY2\nhqF0SYkC586SEsm+IAjiqiasgmE2m5GSkiJ/79q1K8rKygIev2rVqlat5md3L3TEtTbo7XFJuWMY\nDAMpXZdMDIIgrmJa57sJgiiKIR/7/fff48SJE/jiiy9COj4xMcZvm6myAgBg1Bk194fKr3oNQk71\naQxM6Y8LxYekjRwPg0HXqvOGi/ZYpssF1YUXqgsvVBdtQ1gFIzk5GSUlJfJ3s9mMpKQkv+N27dqF\npUuX4ssvv4Rerw/p3OXlFr9tFVW1AABRYDT3h8rAuIH429Du4BzR+F48LG1kBDidrladNxwkJsa0\nuzJdLqguvFBdeKG68NJa4QyrS6pfv34oLCxEcXExHA4HMjMzMWbMGNUxJ0+exMsvv4wPP/wQCQkJ\nrbqeJ4aha/U4DBbdopOh43SAe7lWhlxSBEFc5YTVwuA4DvPmzcOMGTMgiiKmTp2KtLQ0LF68GP36\n9cOoUaPw1ltvwWq14umnn4YoiujWrRs++OCDFl3PyUvZTJ45oFqLjmMA0T3V+fWHwdti2+S8BEEQ\nVyJhFQwASE9P9wtkz5o1S/782Weftdm1HG4Lg2Pb5rY4jpUFg42yoNywF0DLg/IEQRBXMh1qpLfD\nbWHo20owWAai4K0igXG2yXkJgiCuRDqUYDgFdwyjDQXDY2FIUAyDIIirlw4lGK42jmFwHCMHvQFp\napAfC37Bq3sXwim42uQaBEEQVwodSjAcotslxbWNYLAMA4hKq4JBTtVplDRcRJ2d0vQIgri66FCC\n4XIHvfVsaGM5gsEwjHoxJhGwOOsBAA7B0SbXIAiCuFLoUIJhd0mCYdS1jWAAUM98KzKoc0iWhZ23\nt9k12gKn4EJ+XWGzRtd3VBy8EwV1Fy53MdoFdt6Bwrqiy12MdoHNZcMFS/HlLsYVTYcSjEaH1OuP\ndk9m2BawiioSIaLe0QAAcPDty8JYeWoV3jqwBMcqTl7uolx2lp9YiTcP/At51WeDH9zB+fDocrxx\nYDE1lADePfQRXt+/CGWN5Ze7KFcsHUowrC6pEY+JMLXZOZUz3wqMQ5qQEFLPrT1xwCxNYZJPPWtZ\nNIuokcTpmnMAgNKGls8C3VEoqpemKaqwVl3mkly5dCjBsDsll1SbCobCJcVzXjdUexMMz7QlHkEj\nQDVBaELPRcvpUIJhc1sYkQZDm51TKRgi6xWJ9hbDoNUA/SHx9EKxLSVUFy2lQwmGZ2qQthq4BwRe\nW6PdWRju/6lhIIimoXek5XQYwRBFUZ58sK3SaoHAq/fll1Uj70JNm12n1dBMun5Qw+CFrC0vVBct\np8MIht3JQwAPoO1Gekvn0haMPaeK8frKQ6i3to/5pbyrkNPLQPhDT4UXXhQudxGuWDqMYNRbnQAr\nPQhtNdIbCOySYlhJnBrajWA0HfQ+W1KLV784gGqLN/bCCzyWHPkUe0sP4oKlGG8fWNKhMkiaI55O\nwYVFh5fiUFk2ztcW4O0D76PGXhvG0l1axGY0knbegXcPfYhjFSeRV30W7xx8X04nvxQIogi7kw/b\n+V3NmNan0WnFOwc/wKmqPJyszMXCgx/A6rKGrWztnbBPb36psDl4MIz0UrRpDIMNsBgTJz3Q4Xyw\nm4XHJRWgjVy8KhuWRic27C7A9LtvAACUNJhxqioPp6rykBLVFaUNZvxw9j+YcfP0S1ToMNOMbvXZ\nmvPIqz6DvOoziDXEoM5hwab8n/GbGyeHr3yXkOY0kscqTuJMzXmcqTkPlmEhiAKyincho9ddYSyh\nl7e/Ooycwhosff5O6Li279M2Zx64/ebDOFebjyVHPpW37bt4GCN7/KrNy3Ul0GEsDLuDly2MtnRJ\ncYHO5bYwbI52IhhuAvWqBcF/u15DDHmxfd1Pa2iO60HZyfDUYUeqi+Y0klodLkG4dHWRUyjFBsP1\nbrmE0L0CWi7pjvRcNJcrVjB4gVdlKtkcPMBIf8i2Wg8D8HdJiU4pZZdxWxiOdmJhsB6XlChqZnBJ\n8R0RDCOZ2YBk+su/Z6RHwdPIWl02CFe4r9cpODVH5Dt4h3xvnrpQ/p09nwW5Lqwdti7s7roQRVGu\nC2XSiP9z
Yb1kyQROV3jq3Mlr14XNZYcoihBFUXY76Tn/BBrls3O1JVZcsYLxjz1vYU7WXPm7zeEC\nWAEMmMBupBZQWmFTfRftEdIHzuW+bvsQDE/Y++eiHZiTNRfnawvkPQ3ORqD/f6C/NhvFTDae3/4y\nTlXlwSV6e52cu2EQRAE2lw3Pbftf/Ethhl+JbCrYitlZc1FcXypvq7XXYXbWXHyduwbrz23G89tf\nxrnafFWvkVXUhcVRj+e2vYxPjn1xycvflqw7twmzs+aqpsWosFZhTtZcrD2zAavPrMfz219GcX2p\n3PkAvB0RQRRQbavBc9texoqTX1+SMjtd4Xm3vjuzHrOz5qLa5s1yLG0w49lt87Dh/I/4Kvc7PLft\nZZQ1VgSwtgRcbCjD89tfxjd5a8NSxvbKFSsYFTYpOOtRe5uDB8MK6skC24CGRvVDK7oMEGwRYE31\nAMIbnGsOvlm1+y4ekj97gre6LqUoZrIBAEfKj4MXlI2kt1dd655gMa/6TDiLfMk4XHZM/lzWWAEA\n2FmyF//J/wkAcKIy16cuPL1qHuVW6fjsihOXqrhh5WRlnvzZM1XGlgvbsPXCdgBAXvVZVUdCfi4g\nyPNR7XdPQxNuHGGyMDycrc33fq45DwDYkP8TdpbsAwAU1F3QTBYQIOCc+7fbi3eHtYztjStWMDxY\n3Wa05JIS2jR+AQBjh/RUfTewRojWGDB6J6B3tBvB8A1dWF1ey0hpNovuxAAWjCoQyroVhxcFzTHj\nVrsLLl798giiiCWrj+GnA+17/iqboi60rE8GjE8j6XXvNbXKYqHZglc+24eKWv+smUZb+8ie80X5\nXJg4/0k6GTBwKcRTaXkyl3isTyCXlCiK+GDtcWzcW9iq8yvrItYQ47efYdR14UFyz12d456uSMHI\nyfemflocVjh4J/KteQArgGvD+AUATBl+neq7iTNBaJQeLjbCIgXbFdTYa+XJ7wRRwKGybFUansVR\nj1NVedCiuL4UedVncaTsWLN8o5W1NjTa1eVQvgzKxtAFyXfr+zJ4PgsiD9+XodHmxJPvbsPSH9S9\nbEujE4fyynEoT3v2zypbNU5U5gKQYk4HzUdhc3nTemvtloBWTKGlCKerzyK7vPU9e1VdaAR/pbrw\nbmfg9ds3NeXK+2uOodBcj9Xbzqm27zlxEU+9tx0HcsrkbRXWSvnv7hRcOGg+qvKj19hrcbpafR4P\nBXUXcLr6LI5XnGrqNkPCynufRa26YBnfjoR2XRRcbPkCYubGcvnv7uAdOGg+KgflRVEEY7CCja7W\nFIzztQXINufhUOkJfPtz6yxgm+od8RcG346E5/4FgW8z8SxtMOOM27qxuew4aD6qsnYrrJU4X6st\njGdqziOv+ixyqk63SVlC4YoUjL9+uFP+XN3YgG/z1uKQYyNYU2ObWxg6nx5plF4hGJEWPwvj9f2L\n8FH25yipv4jdJfux7PiX+PLU/8n7Fx78AEuOfIpCi/8aBQv2vYtFhz/GJ8e/aNbU3BfK6jUsDG/D\n4BkBr4QFq3oZPI1XlcXmZ4afKZZcWgdy1cJQ4x7TESiO8/c9b+GDo8tQbavBz0U7sPzESvzf6e/l\n/Qv2LcSiw0s1p5t+Y/9ivHf4Y3x8bEWr13NQ1YVWIwlWJZ4WmycpQPBz9YmiiHU7z+NMUS3sTqme\nDDr1M/LzYcl1s+Wgt9wv734DS458ikanFRvP/4TlJ1Zi3blN3v27Xsd7hz9Crc9KjqIo4s0D/8J7\nhz/Gh9mftXpqbmUjqVUXDMOqGk9PxpggqJ+Jv3++v8Vl+Meet7Do8FLwAo81ZzZg+YmV2FzwMwDA\nxQswDciCsc9eNDjUlptLcOHtg+9j6cllMN54END5z+dmd/BY+WOearxRXmE11mzzf58aVe+Iv0XI\nMqymq9IpuNrMvvjn3nfw7qEPAQBf567B8hMrsU3h5np59xt4++ASv6SLRmcj3j30IRYd/hj/OvLJ\nJZuq6IoUDGXPo87WgNOKxrUtB+0B0kMyb9hz8vdoYyREqyQYTKRFbixdvIB6qxMWh7QiX3Fthezz\n9UwxDQBlbp94sCVeq+yhTzui0/k/vqH1qr0vg+eBq6hthM2pfnnyA/QmaxuklzKQW85z3TqHBWer\npSC8Mhhf75QGgzU41Q2D78tR66jTPH+oWF02rNych882nNJMqfTtVTfYbe5y8H6B1+LyBqzZfh4L\nvjwoZ8gZ9OrXyKCXBETLB2912WTfuXKRJ08jbePVSRYOn/L61lVzaQzyXPi6Kj3PhUt0+fWqeaF1\nMQYbb8cZ97vhSUxQdj58BUP5TAPeTEUPdQ0OZO7Jx5aDRVj0f0fl7c8u2oZ1u8/7XT+QFe5h+9FS\nVV14EiPsgqPNJ/t08k7kVkuWwoYjx5F1RD01v83n3n3rwsE7UFzRgL98vBtF5fVtWjYlV6RgKLHY\nG1UpkW2ZUushOSpJ/hylj8DYW24ECx3YCIvcaHy49jhmLdouH/fxumMoqZL+cKFYPXaXuofg8R2H\ngsMp+LlUg70MDMOAV7wMckPFiKi1qh/G/FJJMDrFqn3eNfVSmW0OHrwgYM/Ji35xDkBaCfHwGck9\no5V14tuI+84EHGg+r1BpdNmw5VARtmeXajaSosioc+vdDREvCigqV4ulskHzCKVRry6fQSf97Rwa\nWT523i5fS6su/BsGdaNpaXTg2LlKv99poZUKHNzCYFS9aqf7byNZoL6uytDHdmhhddnkZ9Mz3kHp\n4m1wNt1IesZCAcCR0xV45l87sHmfJMLFFT4j0xn/uvB07gDtusg+VwGHhnVudzlk67qtsPI2ud7r\n6l1YsTEXh097rUnfe7f5vCN23oGVm3NRVm3FF5ty27RsSq5IwdD18FaIxdaA+kbvw6CVN92WROhM\neHDUDegW1RVMRD1+OVqEilorjpaehq6HNzbBsALqGqU/slag1bdhqKhXP+AsGDgFFz47ugqf/LRb\nsyH2YLW7/F4Ia5CGgQULp6KR9BzPRtXhZIXXJ7rmTCYuOiWrwGRQN3A19W4Lw8Fj874LWPrDScx8\n6xdsOH4Qmec2y8eVVNXJ5dOaasX3ZWj0szh4OHgnvsldi4sNZWgKrdiPxe6tW6tTYyyCg1fHMNwD\nQE9V5aHI6vUfr8r7AWdqvdas51K+o5E9FobTKeBkZS425W/1Xt9lky27UOrC9zn5dMMxvLvqID49\n8g0qrE0Lh1bA1mPVSfv9rS1RFDU7GEfKjyO/zlsX+tSTOFUh1cWBnDKV++1gbjl+3O+fCHGs4iR+\nKsySv1tdVjhd0rUEQRIjq8N77UZno+r3vnUBlkdlfR1e+XEZ1uyR4oYeq473HajKagmGtzOgaYWz\nAup8Ok8AsN98CHkV3vv7Kuc7Vd2UVjbgi025ftZpWXUj/r0pF1a7C4fKsvHLBa9r/bON2bB6LHtR\nqot/rT4i728M8o44eAecsELf8zhELnxLL1yRU4Pou3nNy3qnDQLPy
HdivASCAQBdIrqgqKEYjN6O\n1VnnYOyzV30gy0trjBu8vSdlY2b1cT1U1qt7snbeiUMlOThQuQ+uslSkHb0Gowf10CyTJ0NMtY23\nSWZuYR2+2HUSSPH/DR+h3UPcXu5t4H4qzAK6AigY5zdI0WNhNNpdqpl7M8u+UR13zlwll88TE1L2\nfn0byYp6tUm9YvNJpPevwbayXTDpjJiUdo9muStrrfIU90osTgukIA+Depv/y9RodyEmUtuttqdq\nm/z556Id7k/jVMf4irnn71xRa8P7R5ep9lldVtnC4FgOlbU2REcxiv0+DYNfQ2EHl1CPw1XH0MPc\nBeN6jtEsNwAsyzwORKu3VVqrpMCyjxvOQ1b2BQy6qZPm+Tac/1H+rOtaiBVnPsOGLVNRaJb+XqMH\ndQfDMHh/jZTGPGZwD7Cs994+yv5cdT6ry4Z6uwNggNOFdUB/oMZqUe1X3buPtcWwPD7ZsQXlhlw4\nXAyAnprllg72FwzlvGmaU6cwAsw12nNo5dgOyp93lOzFjpK9eHXofPyw4zx+OSKlK3dPjFK9s59t\nyEHuhRpwLINd3Jeq82UXmGG8yQWGAURB6oAwesXAZJ9793Vd2nk7GkwF0MUXwVbnfdkdTh42B4/Y\nqLZZI+iKtDCUNDqsqh68oY1jGL6Y9JJbJlLvXtWP5REVoSFSnEsSDHhdD0pXi+/LUNWgbiR3nSzC\nsi173OdyykG8eqsTe0+a1eJjd4Jh/XvWFxvLcPR0pWYjeaGsXrMH2hS+PnmlWX70bODe7skLZsBd\nPo7hsOt4KaoavHEJ37ow16on/attbMSmbClbqqC8StOKKDRb8Ng/NuOrrf7muAAejEnqrTbY/evi\nXGk1LBo9yVBxOH3E2uHJOPMvp2RhSI2Tpd6F5z/chR8Pn1HtBwBzdSP2nTLLaeMeGJYHEyk1qr4N\nqBKeF7A/76Lf9kaXVR6Xo2V55ptrYXWEHkD1iAXgH8uqD5JabHVZIbpnZ6ixSGWptgUWDD8Lg+NR\naZcsTkbX9LUYDcGoddTJFpfm1CmsgBP5FU2eV8mcJTtlsdDC8zyczPef4JPhXN53WHQ3y3rt9sLh\n5HHkXKny57DzDjj0UqdNUCz09q/Vx/DMv3bgfGmdKmuvpVz5guGywil6K6gt18LQwshJSm3SuRWb\n42EyaOT2cy44eG9PEpCCvx58X4Zaq1owTpdUgYmwyOdqtEsP9NIfTuDjH05g9wmpMcgvrcPq7drp\nhcX1pdLvNF4WvV47ttEUdQ0OrPwxT0468AS9g/7O1ii/sJW1Dny6/hS+yjou71fWhdXuwsqtJ9Un\n4FzgoqX6OVFYhiNn/F/i00VSI7gtW3sdb09dNmg0hgVldfjpYMtz+n1dD02N/le6pKrqpEYu6+R5\nxX4raurtePXfB/HR9yewJ9fnfjgXWLdgWJ2BRa6i1ia71nzxBJg1G0lGQKWlZTPT+sY0ausd+H8/\n5uGERgMJSHUhwlvGaosdtQrBsPE2rN1+Tvbl1/uJpwt2TmokGc7/XmwOlzdBJkBdlNRL75FTa34p\nRoCIlo+z8nWLJcRInc3Sykb/g5Xld78rSgtD+Y5k7i7AtuPq5zWvuAINkOq5rM6CI6eld+TEeWnb\nu98exQdrj6O1hN0ltW3bNixYsACiKOL+++/HzJkzVfsPHDiABQsWIDc3F++++y7uvvvuZp3/jOMw\nlIO723KmWi0i3BaG0T3oSZ9yDj/VHQcbpT5Ol3Ieojt4esFSjKd/fglDEofJ+60uG+oaHWAZBrwg\nYt3eMzBerzgB620YuE5l2C0sR3XW7TgbfxCmgQJOFMUhB1uxP6cM+l7aZf3i1LcwGDqBYRP99rkE\nV7MsDOMtv8B+dCS2HCxCgXgQxdwRONleAG4AIMLQex+E+ni4im70+62u+xk5o6VKLILp1hKcr0sD\nIjx14W0IDuSUQWDULy/D8oBJskh0XUqxrPAd3G8aj435W8AyLP46dDZ2NayDvpcDjF5bxIzXHwFf\nlwCro7f/TlaA3eVEqF0NY7/tsB8bIZWnRy726jeh9KfBEC6m4flpt6C002bohAS4iq/3++33ZzfI\nAcs6w3mYBhfAXuP9A1pdNhw7Vymvs5JbXA50UZaVBxshieeu0n3YZz6EKddlIPPcZhg4A/42dA6W\nHf8SUboEGK7L0Sz/h9mfoW/n3ugWlaxZF1WW0DOxjH13wX5Cmrn1q1NrcarhMLiuN4I398T6PeeQ\nzfyAnTu6YX6nh/1++395P4CH2wrvWohX9v8DaRE3y/uPnr+Ig6WRAIDlfxmN/Xm+4snDZagFA0CX\nXAAu8QKcF26EvsdpiC499uddj12N30N3DQuuk1mz/IsOf4wBXfrByET67WMYIaDQaGHovReOHOkd\n1/c6hu/rNsFUPAUjut8GO+9AbsT34JK6gS/7L//f9vI25vpu56HrWgi+3OvO2pNbhBuib0ZCjBFl\nNVY/gVy39wz0PS3u35/D0vx34Np+A0yD8yA6TKg/fgdMhtZ3psNqYQiCgPnz52PZsmVYv349MjMz\ncfasOh+6W7dueP311zFhwoRWX++G+DQMS7m11efRwp43CC7zNegaIWVMeSwNrpMZbJR/2qlvyp9L\n5LGr2Ju73ui04pnFO/Dy8n3IL63zewAYnROMydvTY1gRObb9gN4ORu9Efv15HCw7CrZTKXRd1Oap\nEoe+SvOhd/Au2TUi8sEfA9Zok3tBhY35ACOC6+R2eXAucLHV0Hc7j+t7xPn91rcuGFaAI8abXqvs\nPVXX2wHORzD0dvA6b12IjID/nP8J9c4G1DksyK0+g4uu89AlFoOLD+xC4GKrNd0tyoYhpLqIaPAG\n8eMqAUZEfmMeThfV4qKlBi5jFfTdz/pllQH+2S0MJ4CP8/YWyyx1qKhRWFw+bifWaAVj8J7DJbiw\n4fyPsqsppyoPOdWncbB8H9jowOnIJypzAvjtRdS6kzVEPnh2GhtVB88goOOVORAhgkuQGucD5/Kl\n/d1y8NyHO/x+6+uHd8GJHIuiF+zzHJRb1O8ZG1EPRqdIVuAE6LufAaNzgTVZsWLndhTWF0Kfkg/W\nGFgEj1Qcw84TGpYpKwKMdG+iEDyNlout9n6OL4MIUR54WmQpgUNXC0PPU2prwlN2nc/7z/HgOnvd\nWycKzXj2/Z1wuni4XALg44JjI+tUbmmG4+WOGhvRACbCgi5xEUHvIRhhFYzs7Gykpqaie/fu0Ov1\nyMjIwJYtW1THdOvWDTfccEObjJx8etAf0bezfw+3LRBqkuAs6CtnwHgEozkweu8fuapBevirLXZY\nGp1+DxEbVecXl1D+vhL5Aa/j55Zj/R9QycJwC4ZDepBEgYVgjfI7Vr6+u4zy/6ZGgOFVYhdhCvhz\nFSKnMLcVDcfFqkY/8WSj/RcyanB5zfpD5qN+++XruNQWp+8YEwAAI4DxNAzuuhDsJgj2Jm7GU0Z3\n3TKR9QAEfLjOm9kSGer7
qXj59+YUyVONdIo1wi74pBhH+4/PaVBkEy3d/qPffg++dRHIDVNRJ51P\ndEj3LzTEQHQ10Tt1p7d6FhWTLGNRTk8GoNlIaqGMRSifA0EQYfcRGK26UL4jus6BO1IRnM/fVsuS\nUFoYLkn8+drOclBaGxFRJp2cgexx/ymTPIb01U4q8Lu84l489feXj/fgTEmtxjvSdF2wkRYkxof4\ncjZBWAXDbDYjJcUbse/atSvKylofeLmceATD0ALBUFLeUA02tgKMwYq9F05IvVYFWo2kEq5T4Hrk\nGE41sEjrYSqutOBAgWTteRoGhhUAIXCvkjFYwRisgLs3xDAAl1AGNtrbs7JFNX9UdqW1CvNXb8Ci\n7/di/4WT4CLUPt5gdXG0iYkBRd4nFVjQqDfOG0iW60LnbLoujFZAb5N7hgwrgI2rQJldEfSMD9xg\nBTyvqRF7L5wAY7QisUcDGGPz6qKp50J0qZ/Z0gZ/N01kBMC4XV6eqfzBuZq0Njx14WnUGJ0LbEwV\n2Cjvc+exOpoDY2oAG1sB6G34n6XfoYFXW0zB6yLwNRle3XhKk4mqiYpSbHd5ljVwAU3UBWu0Ys70\n3rJ1VOuoQ2FdkSrtNqVn8wfWsRH1YGMrUG2rhYUtAWPwFc+m64KNq4CDa/l0Lh7C6vC/lHPF63h/\nH2RbMvvBW3DsXCW6dZauY9SYuK051KMKxt5SQOocAF2IvVHRqYco6Jo0sR28EzpGB6coPbRKU9mD\nrvNFeLZ6GkmgaTeEsY80i6eyh2W4Tt27v2DcieZSaCkC4otwEYBBI8QQCMFuAsMKquBgMCysf+aQ\nrou3kRed0t+V4XiITQiGqa80fYPSVWG88ZDqmMq4fSGXywMXUwOu9wEAQCEAXXTTx3sQbJFgdA4/\n14b6IHX/UDlbqwdnfL43JOhpJHUur3hoYOrn/zc33qSeOsTQq/lzgnFxVeDimr9ksGiNhj7Kpu1y\nc+OwsYBCM7Tcd474s3KPWnAYwEUC0DkhChwYaGdlGW/ZhrePbVNNKfPGgcWqY368+J9Qb0WGi69o\n0t0aiO7RKSiuvwhd54vIRyZ8U8KbS1gFIzk5GSUl3pfRbDYjKSmpiV+0DMFuwi2GSUhM9J9xsq0Y\nnRiD0cN6yt+TXP6+eiXOwhsh2KJgvEFqRETROwW57sKtiO3sRKUuR+WPBgDHuZthuDZwNoNgjQFn\nvtSov7cAACAASURBVAlOUxn013gH2Ik8B74mCbrOpRDAQ+ARsv2oFIymetUeGFaA0BArWRz6ptMZ\nAcBR0Bs6PhrstVJDqKyLJ2/9A1Zs2QNLVK5PaiQDR/5NMPQ86X9CNzpnHKyFqTAm1AIp3nRa0aUH\n6juDib8ITic0Z6VWVcNo0hkRTIoYVgRviQcXaVG7YLTOLQLOwt4Q7ZHyc+EeHgJRBBynByE6wQpH\npzxV3EfkWbiKr4f+vwKP4BUaYjGwy63o05fFN8d/kLcnmOJQYdaBi6tE5zgjqp2h9269FoYTsIXW\nIePrOoGNqZZdfAHPLbBwXrgRoj3C7x0xsHpYcm4GG1EPXfezqmwv0aWD62JP6HtImYEMGL9VJjl7\nAl4YOxmL1/2M+jjvuyTYIgBeDzaqDnanCF+vVJPlVVgYTYmnqhwNSeCjgntUInQmTOs/CZ0i4vH2\nzo+l67nrIkofiT8OmY4SixnfHFunulfRYYSrops8Lk0UWL/MuL7J1+PhpEnIyjmF265XT6TaEsLq\nkurXrx8KCwtRXFwMh8OBzMxMjBkTeKBRSy0SobYLOhvjUV5uuWT/rPWKF9qp9u/ytZ1xW9KvINR4\nxVGwJLgPZlB/sTMcRb3AOvy7jwZLapM+UrExBjd06olZI+4DX9VV3j6g8wCIivhDc6rSxHobA18X\nTsByOI3ISNUeQKc6t7UbeHNPRNq7y9s8daFnDDh+gEV5bg+IPg0SyzB444HfeK+ncT8pkcmIZ1Ng\nvdALfI03lchV3h2dTPEAgGYnhijcNjGmEM0+pxGdGwcGPYyvSgZv7ql6LroaesjXFWqSMOPWCUiO\n7uLzQz1cZdd4v2uIumiNwU0JaUhPGo4Ywftc3J48RLaaWK5575fcSDJo0tpS/cZuguvCDUGP4ytT\nwJtTVXXRKyYVABBtiMZN8b0xrtdoxBli1ed3mFR1EaX3FzK9Mw7dddfgGnEghAbv7/mya7wrZmqk\nmjd5Xwr3XCidKgBwNJjgLAreSA9NHoTB8YPRy5gmb/O8I9G6aKSZrseIxOGINqjji4ItEryiLvSi\nvwJ24rqgp+la/G5ABm6Man18N6yCwXEc5s2bhxkzZmD8+PHIyMhAWloaFi9ejJ9/lmanPHbsGEaO\nHImNGzfi5ZdfbnG2VJe41gd0moMy6M3XqV/waKMRv79H7VsRLFKgSyeaIIoMymtsMDD+DRLLMPID\nyYoaq301xiA6Qo8eSdEqYYmLiAr5QfYlWq8QrhDP0aNTPPp3D5DPq8CzJnqUydtyJ+ok8eAEEyxW\nqQ+v1WvrFBPhnRvMpbFfn4iYCPd2RV2kxMXhum5SfWtNW90UUTrvS9k1LjSLVXTp0SMmJehxqUnx\nuHvINZh4R09527VxUh167j8hxoR4k8Z1FX8XTlC4EN3TSAiNMeifJj2HKZ28jWSEzoTru0uNT3PW\nslaWCUCTfnsVvB6CNQQ/msZz1ruL1LjGGmLw7EMDcV96GjpHxvodp/xtjMF7LUaQnhUjL93vrwf3\nQJTB6zoWeZ13UFwz0mUBAO6gP8MKIYun4OL8OkJaaK7q524vYo3eZyFa75uQwqjKYlLk9nv+dt2j\ngz+XzSHsA/fS09OxadMmbN68WR6DMWvWLIwaNQqAZIVkZWXh8OHD2LNnD9atW9fsa/CVKUiMb33K\nWHNQCUZliqoH7OKlqRfmPupN8R3V93p0jUhCDOMVlwifwRsjuqbD4RLkjBQjEwlnaU8AwJCkWyEK\nDIT6BMRGGRAdofc+/ACiDRFBsjcAoV7bjZZolHp5fHViyIJxXUoXJEcGdy96Jg5U+nRv6poK0R4B\nV0M0KmvdKZxOdUxo4rWSrzVSF+neb8CAToOk+yi/BqLAontUD0SapJdNVNTFqFt6IjFW+h0fQDDi\nWe2yj+s7AADgqkhBbERoz5TI69A3uWfQ43p1jcNDY67H5BHXytuu75wK0WGE0Cg1jF3iTH695ju6\njgDAyFlOejECrkqpIeDLu0PkOSQZU+SBYUadV5xNughc00X6uwcadxPPddHcLtRK22+JH4S+qf5j\nebQQeZ2qVx/wOI1nNS2uJ6L1UeihaOSUsULe5kL5zxYfwYjBzZ1vAgBwlu4QXTpEuweu9E5NwH8l\nSfd+7suj4K2C97qMdl0IARr4G+Ikq8lZ2rN54lmfEPQwrcHGQl0niC49ronpJm9jfSYkdZlTAV7x\nt2Yj0TO6p3TpmkQwvEF7rE0ruKJHeg/pOgjWQ6MhWDpfBgvD+yALNYmwHR4tN8icT
lKPa7t5X5zr\nuyXghSFPYWrqA/K2CL23QbIeGo1f9xgNp0uQXUscy8B14Ub0a5yGaTfeB9uRURBtUeiRGC1NeKd4\n6WKMkSoB0cKeeytsx37lF9hOik6A9eAYOM4MDCn3HpB6rnpOj7dG/N1vn+PczZK/GN6pv5Vz2XSK\nikavugxYTvVFbqE7k8Z9LybOhNeH/y9+/V8jAQBdIjoDkHrnM/o/gLdGvILOllthO3InkqO7yIKh\nrIsInSnoiP+HUh/FX4fOhu8MrCP7XIdOhRMx7YapoSc28Dpc0zkefx86129XVNlQGFnpPFqLeyVG\nxeKGhkkYl5KBJc+MQIRRJzcMXSI647Xh8/DwgLsAeEVVzxjhPNcP1oNj4Mzvg/8d+iJenj5cPqey\nxxqpM8lWmmYaLYCpPR7Bi0Nm+W0fe8tNGGB7GI8PeACdogKnW6vgdYDLiGHib/12/bHf7+TPCdH+\nYhxnjMXLt72AqTdM8tvXPbI7os6MgL40H+/NGiH3tmP0Ufjvfo/gzeEvgynuB9vRkRjRN1X+nefe\nr/3tLWC5qKAWxq3Mfegv+ns5pgwdAOvBMXBduDF095xLB9ERAduRkX777DlD5M+cxrLSoiMCMQV3\na86bdnPnm/Dszc9BqO6qeuc5wYTZt85Er4oH4Mzvi8TSe2DStS45x5crcvJBDxzLyq6K+Ji2rZhg\nqNNqGcQao5GSHI/C+lokd/YvC8fqYNKZ0LdnEgApkBtl1EOeecBlQKRRauQ8QWi7aAXAIMoQCaNe\nJ99rjyS3Ga54WCL1EcGtA14H0RorNa6KoGqXmEhvT0UI7ZHwTMIYqfd/8QVrtFyW7okm9L61BzJu\n74m/uudnjDQacH23KJw6Xw8R0pQJid1icUGQLBGlmyHW/bnR1QiO5RDJRqJrQhRKKqyIidAj0ugu\nr+i99widSTVOQ4teyQmIjtCDY1iVFWIy6DH/Manx/f5saFMpiLwe8TFGxEb6u82eGT8ciw+fhF2w\nq6aTl6+nM+GZ+wMPNlUuHSo6jUBEAwTWLv3t3YMMk+PiVb9RTvFvUghGIPdccnwskqP9e9YPjvb6\n30NNI/dYxz27dMFen+nFukWnyEHqO25OwVqfBQQjdCbN5wkA9DoO0ZX7UV5Wimee+j2c3VmwPU3I\n/PJbXEg9hTNn8vC/ry3Fq/94Cf/O/RKfOhx44IFp0PWU7v3Uwl1IHdUfMXE8sv+1B7GpXWAprIQ+\n1oieD/cH656S/vF7+sPBO/H40s9gzsqHyIvgInWIfNMhWQwuB4o374Wt6iLAMEi+sxfi+iSi7nQl\nLv50DqIoQhepR9pjA1F+6AAYZz5uHT4eJQByl+xFr9/eAkBE1c8rwZ8R0XihDqNe+hXefvt15Oae\nxPnKAsT1TUJc3Gjc0ee/cDo3F4sX///2zjwgynJ7/J+ZYdgZkE1kEVEUccEdUMktrpgrXEWvZup1\nrVxyqUS+Zd+ytG96vdXtdjXNTLMsb9mvm7bp1dJETZOstMUV0QABkX0GmOf3xzADA4MMCiLwfP5i\n3vV5D+/7nOc85zzn/I2iomIydVn4P9SVr/7+EYOejDTJ5vzmU/iO6YitnwIbpQ2d/Dw5cyGP0IDW\nFiR5ZzRthaFQsmZeJEXaUlMd5rtF1YV7M0Z05nBBeSoGC4kAjVla1TZKEh/qQ+r1fHKdfuZcpSzQ\nduU5qYwjyTJRitpGiZ+n+eiuTXlo76Du/hy5blgx7WBT+5SUh8aBrNxibO2gcr48D03FR9on2IfT\nxbWXfDQqDIuUVkyX6SljSrS5E9TTxQkP/4pOblAPX0q8s7iSWt3sdrUzWGk6fUW8Ukd/N36+mI2P\nh2MNFoZDrXVRnMsTRtY0ZQV1WJxZamO6XlVcbJ1NI35LU0KWLCFjidiq77TQGd4LPbdeBFe5/oqj\njQM25fewVB8DwENjX+vC2arWVklKCGXZ1ac7nNVO4KjiP99eorjYfGT94i9nKSoeBMAXP1W00TgC\n/0yfxoPR5lM4xvdBqVDwyCMLuXTpAlu27OC15M18d/IYWZczmLd6AT4+hrb8428v4eLiglarZc6c\naQxYEGO4kALG39eZayWlnMguos3kQHzHBnPpg5+4eSaDVmEVz2KjVOEU6EbHuQZFnnXyGh/9+11a\nufTlt+N7UTnbEzLfkAKkrLiU0gIdqZ/8QvCsPti62VNWntbF+E56t3LkWnkbjGRl/EH7sb3wHx2C\nq4cb8+bNx8XFhUf3PcH5raeYMdONmHB/pj44gVWr/o+QkM68cHgdV4vT6TiwG/v37SXhwVno8q+z\n/AclDq2daeVm6D9G929HgJczYcEeNf9Db5MmrjBUtG7VsOsvakKpUBLh04fDxw2LYezUSpNSqGz6\nz+gymYOp3xLSqiKvULCfK8F+ruSXuHH2xq/8/p3hZVUplSyOD+M/R+2xdSlhVPtoggZ2MCmSByLa\noivRm+ovuDk7QnmNFQcbe/Q3PdHnu1Ka1g6hszdLuT6gTTgTB0VSWqbn8W8/N20vve6Lb1jFiN5W\neetO0sfRG7VKTadWFaPPv4T8mTNZv9LFI4Sfs85yXOtoUl6VE9xNDhnP0T9O0MmzLcK94uvx93Im\nOHAoF3IvE99xrNn9YtoN43LuFTPTfHi/AAb39MXBzgbHcme6qDIl1ce7J99eO84D7e7HwcaeV069\nYdpvn1sRjVKZKN8Is9+1TUn5OvmgVKiY8ZcHTJ37hI5juXDzEu1d23E+5yLOaieLU0ITO8WSfP0n\nPByqz3GP6zCC9MIMpnSeYNo2d0wXvjljh975JJqcXmTXsA4AzKtO2tvYE+UXwanrpxnXYSR//yDZ\nUN60nJI/2pkWo1ZmqH+U2W/7WmRhTLBpo7JBWa7wHNQOlOnLsFGqKNWXoUCBQqFACPNAWDuVPXpR\natH5G99pHFt+3sGDnSdAXoXC+0tIHJfPXsC+c6hJWQB88MG7HDpkqLmRkZFBQWaeYb2FgB5BrelU\npOFD97dZ/KcFFJYU8sKh59HdMPjRHihPFa9UKNHdLOba++cozdfiqHTgYuB5fPsNJjnzdwL6RWH8\n8FT2NuT+molzOzdsy1dSt/MKRKlQkVpyBYUSfNwdKLnQAVHyPX1ahYO6mJutU3D0MwyGSvQl7N//\nBZ988jE5RTmU5WhxFHmkXrmMp6cXISGGAJq/9pzK9rPvMzU+nmUPL2D+/MVs2vQO8eMmkuacR2z5\nN6JUKujVyTqfU11pkgrDaNZWHY3ebaZ1mcTBTwy1I1QqpWkkV3nBUD+fXvTzsRxy6ax24vG+85nz\n34OmbWEdPMujXQZUOz5+qHmIXuVRtIONPQvjelGk7c7mMwZbf1HPubyabOgoHww1dD5qmwqZ6S52\npex6ABonW8I6eHD6fJbBF3KL2Zy2Gn+md/mL2bb7/CK5zy/S9Pfxvf81WRiVZRHlF0GUX3mnrDQ4\neDNvFhPg7YSrnSNP9l1Y7X4aWxce77vA
[base64 PNG data elided: matplotlib figure output]",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f971b401110\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#@test {\"timeout\": 90}\n",
+ "with tf.Graph().as_default():\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=max_steps,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 500)\n",
+ " test_ds = setup_mnist_data(False, hp, 100)\n",
+ " tf_train = autograph.to_graph(train)\n",
+ " (train_losses_, test_losses_, train_accuracies_,\n",
+ " test_accuracies_) = tf_train(train_ds, test_ds, hp)\n",
+ "\n",
+ " with tf.Session() as sess:\n",
+ " durations = []\n",
+ " for t in range(burn_ins + trials):\n",
+ " sess.run(tf.global_variables_initializer())\n",
+ " start = time.time()\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = sess.run([train_losses_, \n",
+ " test_losses_, \n",
+ " train_accuracies_,\n",
+ " test_accuracies_])\n",
+ " if t \u003c burn_ins:\n",
+ " continue\n",
+ " duration = time.time() - start\n",
+ " durations.append(duration)\n",
+ " print('Duration:', duration)\n",
+ "\n",
+ " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " print('test_accuracy', test_accuracies[-1])\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "A06kdgtZtlce"
+ },
+ "source": [
+ "# Eager"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "hBKOKGrWty4e"
+ },
+ "outputs": [],
+ "source": [
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "HCgTZ0MTt6vt"
+ },
+ "outputs": [],
+ "source": [
+ "def train(ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " train_losses = []\n",
+ " test_losses = []\n",
+ " train_accuracies = []\n",
+ " test_accuracies = []\n",
+ " i = 0\n",
+ " train_test_itr = tfe.Iterator(ds)\n",
+ " for (train_x, train_y), (test_x, test_y) in train_test_itr:\n",
+ " train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n",
+ " train_y = tf.one_hot(tf.squeeze(train_y), 10)\n",
+ " test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n",
+ " test_y = tf.one_hot(tf.squeeze(test_y), 10)\n",
+ " if i \u003e hp.max_steps:\n",
+ " break\n",
+ " with tf.GradientTape() as tape:\n",
+ " step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n",
+ " grad = tape.gradient(step_train_loss, m.variables)\n",
+ " opt.apply_gradients(zip(grad, m.variables))\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ "\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " return train_losses, test_losses, train_accuracies, test_accuracies\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 789
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 56025,
+ "status": "ok",
+ "timestamp": 1531163800231,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 240
+ },
+ "id": "plv_yrn_t8Dy",
+ "outputId": "68be955d-61dd-43e4-b540-3794e3c8f990"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Duration: 4.2232978344\n",
+ "Duration: 4.2386469841\n",
+ "Duration: 4.24286484718\n",
+ "Duration: 4.24036884308\n",
+ "Duration: 4.25758385658\n",
+ "Duration: 4.23242998123\n",
+ "Duration: 4.4213449955\n",
+ "Duration: 4.29613113403\n",
+ "Duration: 4.28209114075\n",
+ "Duration: 4.24192905426\n",
+ "Mean duration: 4.26766886711 +/- 0.055508619589\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXdgFGX6x78zW7KbTSE9JIA0pQkIooCgqBx2qiK/O0XU\n8zyFAw/w8MSCFcuJCHqKoFiwIHIgIgooaGjSCU1aaCEJ6W1btszM74/ZmZ2tWchuQjbP55/s7szO\nvDPZeb/vU97nZQRBEEAQBEEQ9cA2dQMIgiCI5gEJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBEGE\nBAkGQRAEERIkGARBEERIkGAQRIjs3r0bt99+e1M3o14KCwvRtWtX8Dzf1E0hogwSDKLB3HzzzejZ\nsyeqq6s9Ph85ciS6du2KoqIiAMC///1vdO3aFQcPHpT3yc/PR9euXeX348ePx/Lly+X3CxYswNCh\nQ9G3b1/ceOONmDZtGgDgrrvuQt++fdG3b190794dvXr1Qp8+fdC3b18sXLjQp43vvfceZsyY0aDr\n7NevH3766acL+s6HH36IuXPnYufOnRgyZEiDzi/hfY/8wTBMWM5FEErUTd0AIjpo06YN1qxZg/vu\nuw8AcPz4cdhsNo+Oi2EYtGrVCu+88w4+/vhjj8/9sXLlSqxevRqfffYZ2rRpg4qKCmzcuBEA8MMP\nP8j7jR8/HqNGjcLdd9/doGsQBCHsHW1OTg6efPJJOBwO6sSJZg9ZGERYGDlyJFauXCm/X7lyJUaP\nHu2z3+jRo3Hs2DHs3r273mMeOnQIgwcPRps2bQAAKSkpGDt2rN99g1W42bx5MxYsWIAff/wRffr0\nwahRowCIQjN37lz8+c9/xlVXXYWCggKsWLECd9xxB/r27Ythw4bhm2++kY/jbSXcfPPNWLx4MUaM\nGIFrrrkG06ZNg91ul7fX1tbi7Nmz6N69Ox599FGUlpbKVlBZWRkEQcDChQsxbNgwDBgwAFOnTkVt\nbS0AwG6341//+hf69++Pa665BmPHjkVlZSXmzp2LPXv24OWXX0bfvn3xyiuv1HsfS0tL8fjjj6N/\n//649dZb8e2338rbDhw4gLvvvhtXX301Bg8ejDfeeCPo+QHAZDLhmWeeweDBgzFkyBC888478v3P\nz8/H+PHj0a9fPwwcOFC2CInogCwMIiz07t0bq1atwqlTp9C+fXusXbsWX331FebOneuxn06nw2OP\nPYa3334bX331Vb3HfPXVV5Geno7+/fuje/fuYNkLH+Ncf/31eOyxx5Cfn48333zTY9vq1auxaNEi\ndOjQATzPIyUlBQsXLkSbNm2we/duPPLII+jVqxe6desGwNcaWrt2LRYvXgytVov/+7//w8qVKzFu\n3DgAwJYtWzBgwADodDosWrQIM2bMwG+//SZ/99NPP8XGjRvx5ZdfIikpCa+88gpefPFFzJkzBytX\nroTJZMLmzZuh0Whw5MgRxMTEYOrUqdi7dy9GjhyJe+65J6TrnzZtGrp06YL58+fj5MmTeOihh9C2\nbVsMGDAAs2fPxoQJEzBixAhYrVacOHECAAKeHwBmzJiB9PR0bNiwAWazGY899hiysrJw7733Yt68\neRg8eDCWLFkCu92OQ4cOXfD/i7h0IQuDCBsjR47Ed999h61bt6Jjx45IT0/3u9+9996L8+fPY/Pm\nzUGPN2LECDz33HPYunUrxo8fj+uuu85vfKIhjB49Gp06dQLLslCr1RgyZIhs0fTr1w+DBg0Kag09\n8MADSE1NRUJCAm666SYcOXJE3vbbb78FjVssW7YM//znP5Geng6NRoNJkyZh3bp14HkearUa1dXV\nOH36NBiGQffu3WEwGC74+s6fP499+/bhySefhEajQdeuXTF27FisWrUKAKBWq5Gfn4+qqiro9Xr0\n6tVL/tzf+SsqKrB582bMnDkTMTExSE5OxoQJE7BmzRr5e4WFhSgpKYFWq0Xfvn0vuM3EpQtZGETY\nGDFiBO6//34UFBRg5MiRAffTarWYOHEi5s2bhzlz5gQ95l133YW77roLHMfhl19+wfTp09GjRw8M\nGjQoLG3OzMz0eJ+Tk4P3338fZ86cAc/zqKurQ5cuXQJ+PyUlRX6t1+tRVlYGQHSRbdu2DU8//XTA\n7xYVFeEf//iHbDUJggC1Wo3y8nKMHDkSxcXFmDZtGoxGI4YPH45p06ZBpVJd0PWVlZUhMTERer1e\n/iwrKwuHDx8GAMyePRvz5s3D7bffjrZt22LSpEm48cYbfc4/YsQITJ06FYWFhXA6nRg8eLDcZkEQ\n0Lp1awCi9fHOO+/gnnvuQatWrfDggw82OLZEXDqQYBBhIysrC9nZ2di0aRNmz54ddN8xY8bgo48+\nws8//xzSsVUqFW699VYsXLgQJ06cCJtgKF1MdrsdTzzxBP7zn/9g6NChYFkWkyZNChofCcTBgwfR\npk0bJCUl+ZxHonXr1pg9ezb69Onj9xiTJk3CpEmTUFRUhL/97W/o2LEj7r777gsKnqenp6OmpgYW\niwWxsbEARKtDsv7atWsni/a6deswZcoU7Ny5Ezqdzuf8HTp0wA033ICYmBjs2LHDbztSUlLw8ssv\nAwD27NmDhx56CNdeey3atm0bcpuJSxdySRFhZfbs2fjss8+g0+mC7qdSqfCPf/wDixYtCrjPypUr\nkZOTA7PZDEEQkJOTg5MnT8pukwshNTUVhYWFQTt/h8MBh8OBpKQksCyLnJwcbN269YLPBYjuqBtu\nuEF+n5KSgurqaphMJvmzcePG4e2335bTjisrK7FhwwYAwI4dO3D8+HHwPI/Y2Fio1WrZukhNTcW5\nc+eCnl+6zszMTPTp0wdvv/027HY7jh49iuXLl2PEiBEAgO+//14OZsfHx4NhGLAsG/D8aWlpGDRo\nEGbPng2TyQRBEHDu3Dns2rULgBjTKSkpAQAkJCSAZdmLijsRlyYRtTCKi4sxY8YMlJeXQ6VSYezY\nsXjggQc89tm5cycmTpwoj0CGDRuGiRMnRrJZRJhRjjS9R5LBRsN33XUXFi5cCKPR6Hf/uLg4LFiw\nAKdOnQLHccjKysILL7zg4xcPZcR922234fvvv0f//v3Rpk0brFixwud7BoMBzzzzDJ544gk4HA7c\ndNNNGDp0aMBjBjtvTk4OXnrpJfl9x44dceedd2Lo0KEQBAFr1qzBhAkTAAAPP/wwysrKkJKSgttv\nvx1Dhw5FeXk5Zs2ahZKSEhgMBtxxxx1yJ//AAw/gqaeewtKlSzFixAg888wzQds2Z84czJo1C9df\nfz0SExPxxBNPYODAgQDEDLLXX38ddXV1yM7Oxty5c6HVaoOe/4033sBbb72FO++8ExaLBW3btsUj\njzwCQLSsJDFJTU3FM888g+zs7KD/G6L5wERyxb2ysjKUl5ejW7duMJvNGDNmDN5//3106tRJ3mfn\nzp1YvHgxFixYEKlmEESjUlFRgVGjR
tUb1CeI5kZEbcW0tDQ5HdFgMKBTp04oLS2N5CkJoskxGo1B\ng90E0VxptKB3QUEBjh496tf/nJubi1GjRiE9PR0zZsxA586dG6tZBBF22rdvj/bt2zd1Mwgi7ETU\nJSVhNpsxfvx4TJw4EX/60598trEsC71ej5ycHMyePRvr1q2LdJMIgiCICyTi6QtOpxNTpkzByJEj\nfcQCEF1VUo74kCFD4HA4fIrYedMIGkcQBEF4EXGX1MyZM9G5c2c5I8Sb8vJypKamAhDr2gBAq1at\ngh6TYRiUlRmD7tNSSEuLp3vhgu6FG7oXbuheuElLi2/Q9yMqGHv27MHq1atxxRVXYNSoUWAYBlOn\nTkVRUREYhsG4ceOwbt06fP3111Cr1dDpdD61hwiCIIhLg0aJYUQCGjGI0OjJDd0LN3Qv3NC9cNNQ\nC4OmYBIEQRAhQYJBEARBhAQJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBKHAZDJh5crlF/XdGTP+\nCbPZVP+OLhYvXoilS7+4qHM1BSQYBEEQCozGWqxc+a3fbTzPB/3um2++A4MhLhLNuiSgFfcIgiAU\nLFjwHoqKCvHww/ehX7/+GDhwED75ZBFSUlKRl3ccS5Ysw9NPP4myslLY7TaMHftnDB8+CgAwduwI\nfPzxElgsFjz55BT07HkVDh3aj7S0DLz++hxotdqA5z1x4hjeeut12Gw2ZGdn4+mnZyEuLg7ffrsU\nq1atgFqtRvv2HfDCC69i3749mD9/jmvdEwb//e8ij2V4IwUJBkEQlyzLNuZh19GGLYmgUjHgOPf8\n5Gu6puPemwNXxH788ck4c+YUFi/+EgCwb98eHDnyB5YsWSavAT9z5izEx8fDZrPhb397AEOG3IyE\nhAQA7oWrCgrO4cUXX8NTTz2D559/Gr/9thG33HJbwPO+8soLmDbtKfTufRU+/vhDfPLJQkyePA1f\nfvkZli9fDbVaLbu7li79AtOn/xtXXtkLdXV1QYUonDRLl9T3m042dRMIgmhBdO/eQxYLAFi27Cs8\n+OBf8Pe/P4TS0lIUFOS7triFqXXrLHTqJApTly5dUVxcFPD4ZrMJZrMJvXtfBQC47bY7kZu7DwDQ\nufPleOGFZ7B+/U9gWXGZ3p49e2P+/LexfPlSGI21jbYMbrO0MBatOoTOrQcirVXkTTCCIJqOe2/u\nHNQaCIVwlAZRrlG/b98e7N27GwsXfgqtVovJk/8Ou93u8x3lqJ9lVX73URKoStN//jMPubl7sWVL\nDj799CN88cW3uP/+B3Hdddfj99+34O9/fwjvvPM+2rW77CKvLnSapYUBAAVloWciEARBhEpsbCws\nFkvA7WazCfHx8dBqtTh79gwOHz7kd78LKdNnMMQhISEBBw7kAgDWrfsRV10lrl1fUlKMPn2uxuOP\nT4HZbILVakFhYQE6duyE++6bgC5duiE//0zoF9gAmqWFAQAFZWb0uTytqZtBEESUkZCQiJ49e2PC\nhP9D//7XYeDAQR7b+/e/Dt999z88+OBf0K7dZbjyyp6Kre4YhhiQDp2ZM1/AW2+9BpvNhqysbMyc\nOQtOpxMvvfQczGYzAAHjxt0HgyEOixZ9gL17d0OlUqF9+44YMGBQvccPB82yWu3w6atwbbd0PDby\nyqZuSpNDlTjd0L1wQ/fCDd0LNy2yWq1Wo0JJpbWpm0EQBNGiaJYuqYTsMpTXNk4aGUEQBCHSLC0M\nc/pO8J23wGbnmropBEEQLYZmKRgSFbV1Td0EgiCIFgMJBkEQBBESzVIwkrViOm1JTW0Tt4QgCKLl\n0CwFIzsuGwBQYa5p4pYQBBFtNKS8OQAsW/Y1bDab322TJ/8dx44dvehjNzXNUjCS9YkAgCorWRgE\nQYSXYOXNQ+Hbb7+GzRad7vJmmVabFp8EAKi1U3kQgiDCi3d584kTp+Crr5bg119/hsPhxA033IiH\nH34UdXV1eP75f6OsrBQ8z2PChEdQWVmO8vIyTJ78GFq1aoV58z4IeJ6ff16LL774FAAwYMAgPP74\nZPA8j9dffxnHjh0BwODOO0fg3nv/7LfEeVPQLAWjdWIyAMDoJJcUQUQzK/J+wL7Sgw06hoplwPHu\nghZ90ntiTOe7Au7vXd58167tKCjIx6JFn0MQBDz11DTs35+L6upKpKam4c033wEAWCxmxMYa8M03\nX+Pddz90lTv3T3l5ORYseA+ffPIl4uLiMXXqJGzZkoO0tAyUlZXis8+WAoBcztxfifOmoFm6pDol\ntwUAmJmKJm4JQRDRzs6dO7Br1048/PB9ePjh+5CffxYFBfno2LEzdu/eiQUL3sP+/bmIjTW4viFA\nWebcH0ePHkbfvv2QkJAIlmUxbNhtyM3dh6ysbJw/X4R33nkLO3b8Lh/TX4nzpqBZWhiZ8ekAp4Zd\nU9nUTSEIIoKM6XxXUGsgFBpaS0oQBIwf/yBGjBjts+3jj7/A779vxYcfvodrrx2ABx98JORj+ivj\nFx8fj08//Ro7dvyOFSuWYePGn/H008/7LXHeWGtgKGmWFgbLsNA6kiHEmGB1RGdwiSCIpsG7vHn/\n/gOwZs33sFrF+nXl5WWoqqpCeXk5YmJicMstt+HPf74fx48fc33f4KouG5ju3a/E/v37UFtbA47j\n8Msv63DVVX1RU1MNnucwZMhNeOSRx3HihHhMfyXOm4JmaWEAgA7xsKMUJaZKtE/KaurmEAQRJXiX\nN584cQrOnDmDxx57CIAoKM899zIKCs7hv/+dB5ZloFZr8OSTTwMARowYhSefnILU1DSfoLdU8jwl\nJRV///skTJ78dwDAwIGDMXjwDcjLO4HZs1+EIPBgGAaPPTY5YInzpqBZljcHgH9+tQBFqv24r8ME\nXNehR1M3p8mg0s1u6F64oXvhhu6FmxZZ3hwAWunECy+urWrilhAEQbQMmq1gpBrEyXtlJkqtJQiC\naAyarWBkJoiT96rqyNQkCIJoDJqtYGS3EgXDaCfBIAiCaAyarWBkuMqDWLmmSS8jCIJoaTRbwTBo\nYwEBcIDW9iYIgmgMmq1gsAwLhteCY2x+Z0wSBEEQ4SWiglFcXIwHHngAd9xxB4YPH47PP//c736v\nvPIKbrnlFowcORJHjhwJ+fhq6AC1A1Ybre1NEAQRaSI601ulUuHpp59Gt27dYDabMWbMGAwaNAid\nOnWS98nJyUF+fj7Wr1+P/fv3Y9asWVi2bFlIx9cyMbCrjKgx2xCra7aT1gmCIJoFEbUw0tLS0K1b\nNwCAwWBAp06dUFpa6rHPhg0bMGrUKABA7969YTQaUV5eHtLxdaweDCug3EiZUgRBEJGm0WIYBQUF\nOHr0KHr16uXxeWlpKTIzM+X3GRkZKCkpCemYsepYAEC5iVbeIwiCiDSN4scxm82YMmUKZs6cCYPB\n4LHNX8BaKtAVjLS0eCTHJeBcNWCFrcE1UpozLfnavaF74YbuhRu6F+Eh4oLhdDoxZcoUjBw5En
/6\n0598tmdkZKC4uFh+X1xcjPT09HqPW1ZmRCyrAwAUlFe02OJiVFjNDd0LN3Qv3NC9cHPJFx+cOXMm\nOnfujAkTJvjdPnToUHz33XcAgNzcXCQkJCA1NTWkYycbxCUQq620tjdBEESkiaiFsWfPHqxevRpX\nXHEFRo0aBYZhMHXqVBQVFYFhGIwbNw5DhgxBTk4Ohg0bBr1ej9deey3k46fFiYJRS+VBCIIgIk5E\nBePqq68OaV7F888/f1HHz4gTLRETRxVrCYIgIk2znekNAGl6UTDsrJFmexMEQUSYZi0YerUOKl4H\nQWuGxeZs6uYQBEFENc1aMABAJySAibHCaLU1dVMIgiCimuYvGEwcGEZAhYUm7xEEQUSSZi8YWlYL\nALDY6pq4JQRBENFN8xcMlQYAYCLBIAiCiCjNXjBiVC4Lw06CQRAEEUmav2CoYwAAFoe9iVtCEAQR\n3TR7wdCpRQvD6iALgyAIIpI0e8HQa0QLo85JabUEQRCRpNkLRqwsGOSSIgiCiCTNXjAkC8NGgkEQ\nBBFRmr1gGGLENTFsPAkGQRBEJIkawbBzJBgEQRCRpNkLRrxWFAwH72jilhAEQUQ3zV4w9FoxhkEW\nBkEQRGRp9oKhkyfu2cDTmhgEQRARo9kLhlR8UGCcqKihyXsEQRCRotkLhlRLCiyP8xXmpm0MQRBE\nFNPsBUPFqsBCBUblREmltambQxAEEbU0e8EAgBiVDlA7YKVlWgmCICJGVAiGXqUDo3LQut4EQRAR\nJCoEI1YTC6idsNhoLgZBEESkiArBMGhiwTACzDaKYRAEQUSKqBCMeK0eAGB2UlotQRBEpIgKwTBo\nDQAAq9PSxC0hCIKIXqJCMGLVooVhJQuDIAgiYkSFYBg0sQAAG08xDIIgiEgRFYIhWRh2gZZpJQiC\niBRRIRgJMfEAAE5lAcfzTdwagiCI6CQqBCPLkAkAYPUmWG1cE7eGIAgiOokKwUjQxkMlxICJNVJ5\nEIIgiAgRFYLBMAwMSAITY0GNlVJrCYIgIkFUCAYAxLFJYBig1FTR1E0hCIKISqJGMOLVCQCACmt1\nE7eEIAgiOomoYMycORPXXXcdhg8f7nf7zp070a9fP4wePRqjR4/G+++/f9HnStSKglFlI8EgCIKI\nBOpIHnzMmDEYP348ZsyYEXCffv36YcGCBQ0+V6uYVoAZqLHXNvhYBEEQhC8RtTD69euHhISESJ5C\nJlnfCgBgcpJgEARBRIImj2Hk5uZi1KhRePTRR5GXl3fRx0mLTQIAmDljuJpGEARBKIioS6o+evTo\ngV9//RV6vR45OTmYNGkS1q1bd1HHitfpIDjVsKmonhRBEEQkaFLBMBgM8ushQ4bgxRdfRHV1NVq1\nalXvd9PS4j3eMxo1hF0aOFQ2n23RTku73mDQvXBD98IN3YvwEHHBEAQh4Lby8nKkpqYCAA4cOAAA\nIYkFAJSVebqerDYnwGnghNlnWzSTlhbfoq43GHQv3NC9cEP3wk1DhTOigjF9+nTs2LED1dXVuPHG\nGzF58mQ4HA4wDINx48Zh3bp1+Prrr6FWq6HT6TB37tyLPpdOq4Lg1EBgONg5B7QqTRivhCAIgoio\nYMyZMyfo9vvuuw/33XdfWM7FMAzUgg4CAIvTAq0qMSzHJQiCIESaPEsqnGjZGACA2UH1pAiCIMJN\nVAmGjhVX3qutMzVxSwiCIKKPqBIMaeW99w4sgpOnMucEQRDhJLoEQxsjv66xUVYEQRBEOIkqwcjW\nXSa/5gRaeY8gCCKcRJVgZMalwlnSFgAJBkEQRLiJKsGI02sBgQEAcDwJBkEQRDiJKsHQaVUQBPGS\nyMIgCIIIL1ElGDEaldvCIMEgCIIIK1ElGFoNC0gWBrmkCIIgwkpIgvHjjz/CZBInw82bNw9//etf\ncejQoYg27GIQLQzxkpxkYRAEQYSVkATjgw8+QFxcHA4cOIAtW7Zg1KhReOWVVyLdtgtGq1EBPAW9\nCYIgIkFIgqFWizUKt27dirFjx2L48OGw2WwRbdjFEKNhKehNEAQRIUISDIZh8P3332PNmjUYOHAg\nAMDhcES0YReD1iPozTdxawiCIKKLkATj2Wefxdq1azF27Fi0bdsWZ86cQf/+/SPdtgtGrWLBgoLe\nBEEQkSCk9TD69u2L999/X37fvn17PPfccxFrVENQs+IlUdCbIAgivIRkYbz++uswGo1wOp34y1/+\ngquuugqrVq2KdNsuCjWrAgDwZGEQBEGElZAEY9u2bYiPj8eWLVuQkZGBdevWYfHixZFu20WhUYmC\nQRYGQRBEeLmgiXu7du3CsGHDkJGRAYZhItWmBqFWiS4pypIiCIIILyEJRkpKCp599ln8+OOPGDRo\nEJxOJzju0uyQta4YBgW9CYIgwktIgjFnzhx07twZc+fORWJiIoqLi/HQQw9Fum0XhUYtuqTsHK24\nRxAEEU5CEozk5GTcf//9MBgMyMvLQ2ZmJsaMGRPptl0UGpdLyu4kwSAIgggnIaXVHjx4EFOmTIFW\nq4UgCHA6nXj33XfRo0ePSLfvgolRawAANuelN7GQIAiiOROSYLz66quYPXu2PMt7+/btePnll7F0\n6dKINu5i0GlEwSALgyAIIryE5JKyWq2yWADAgAEDYLVaI9aohhCjcbmkKIZBEAQRVkISDL1ej+3b\nt8vvd+7cCb1eH7FGNQS9RgsAsJGFQRAEEVZCcknNnDkTTzzxBLRasTN2OByYP39+RBt2scRqtYAd\ncJBgEARBhJWQBKNXr15Yv349Tp8+DUEQ0KFDB9xyyy347bffIty8C0cfowZMgIMnwSAIgggnIQkG\nAGg0GlxxxRXye0EQItKghhKrjQEAOCiGQRAEEVYuek3vS7U0iCHG5Ta7RGeiEwRBNFeCWhh5eXkB\ntzkv0RhBrEswnFQahCAIIqwEFYxHH3004LaYmJiwNyYcGHTiPAwnxTAIgiDCSlDB2LhxY2O1I2wY\nXEJG1WoJgiDCy0XHMC5VpBgGVaslCIIIL1EnGFqNCgLPws6aYeeonhRBEES4iDrBAABVdTvwagt2\nFu9p6qYQBEFEDREVjJkzZ+K6667D8OHDA+7zyiuv4JZbbsHIkSNx5MiRsJzXYGsDAKi1G8NyPIIg\nCCLCgjFmzBh8/PHHAbfn5OQgPz8f69evx0svvYRZs2aF5bx6jQ4AYHPaw3I8giAIIsKC0a9fPyQk\nJATcvmHDBowaNQoA0Lt3bxiNRpSXlzf4vAaXYJjtdQ0+FkEQBCHSpDGM0tJSZGZmyu8zMjJQUlLS\n4OPG6UgwCIIgwk2TCoa/elThKDkSFyOWXrc6bA0+FkEQBCEScvHBSJCRkYHi4mL5fXFxMdLT00P6\nblpafMBtrVNbAeWAE46g+wXCzjnwxf4VGNbperRNzLrg7zc2F3ON0QrdCzd0L9zQvQgPEReMYFVt\nhw4dii+//BJ33HEHcnNzkZCQgNTU1JCOW1YWOANKz
bMQBMBkswbdLxCbCrZh7YnfsOXMLrxxfXgC\n8ZEiLS3+oq4xGqF74YbuhRu6F24aKpwRFYzp06djx44dqK6uxo033ojJkyfD4XCAYRiMGzcOQ4YM\nQU5ODoYNGwa9Xo/XXnstLOeN02sAXgU7d3FZUmaHuPysyWEOS3sIgiCigYgKxpw5c+rd5/nnnw/7\neRMMWoBTw666OMHgqQ4VQRCED1E50zsxLgYCr4JDcJcGKbdWYMf50GZ+8wIPAGCZqLw9BEEQF0WT\nBr0jRaJBC3AqcIJV/mzW728AANrGZyMrLjPQVwEAnCQYuDQXiSIIgmgKonIIHaNRgRXUEBgnzhkL\nPSrXStZDMHiQhUEQBOFNVFoYAKBmtHAywOu75uGWy26SPw9lnQy3S0oVsfYRBEE0N6J2CK1hNfLr\nXcX75NehLN0qCYaKLAyCIAiZqO0RdSq9/Fq5XKuDr3+NDCmGEY5Z5wRBENFC1ApGK1WK/FoZtwhl\nrW+BLAyCIAgforZHTNG6S4ywrPsyQxEMjmIYBEEQPkStYGTo3amzVoc7vdYRimDwlCVFEAThTdT2\niK1iDXCWtwYAOBWZUaEIhtMV5yCXFEEQhJuo7RENOg0c57r4fB6KS0oKjDMkGARBEDJR2yMaXAUI\nvXGGkCXCp1jIAAAgAElEQVRld4kKS1lSBEEQMtErGDo1IPheXiguKYerym3gwuwEQRAtj+gVDL0G\n4C9OMOwuK4T3M8nvaOUJvLN3AaxOq882giCIaCZ6BUOnBsAAgqdb6UJiGP7KiLybuwgnqk9hx/m9\nYWknQRBEcyFqBUPFstDHqHzcUiEJBud07Ru4jIhADiuCIFoYUSsYgJgp5R34DqU0iCQqoRQqJAiC\naClEvWAIvLdLKvTig1yQUuhkYRAE0dKIbsHQqyF4Bb5DKz7Iefz1i0CCQRBEyyK6BUOnkWMYrOtS\nL6SWlL8sKQmSC4IgWhrRLRiK1NoYVgcgxFpSLsvCKXAQAlgS5JIiCKKlEd2CoZy8x4kLKtVnYQiC\n4FEOPZQlXRsbk92M8+aSpm4GQRAtjKgWDH2MGoJLMBw2FipGVa+F4S0Qmwp/xz9/m4kamzFi7bxQ\nntn2Kl7ZMScka4kgCCJcRLVg1Nk5MBobAMBm0ULNqOsNentnRi0/8T0cvBMHyw97fB7IVdUYSFaS\ng6s/gE8QBBEuolowEmI1YLRiCQ/epocKGtQ564J+hw+QGcXg0itESPNECIJoTKJaMG7qmw2GFS0B\nwa6H4FTDytXBZuewcPVhFJabfb4TbO6Fkksh6B1KxhdBEES4iGrBUCmWZtVy8bDZWFiddVi78yy2\nHy7Bf77yrQcVapC7KV1SEqFMQiQIgggXUS0YADD96okY1u5GXNHqcjhsLHiBh91Vvtxo8YwBcDyH\n1afWhXTcxnIH8QKPUkuZX4FyCmRhEATReES9YHRMbI9Rne9AWis9BE4NAMgXDoBtVeLjVPr9/C5s\nLdoR0nG5Rhrdrz/7K17c/h/8fn63zzZySQG1diNKzKVN3QyCaBFEvWBIxMaoAac4F+MUvxsxV+zz\n2cdo941pSHjHLEKNdTSUPSX7AQCHKo74bCPBAJ7e8jJe2vFWUzeDIFoELUYwDDqNbGHIsKFbCd7x\ngkshQ4kEw82lEFMiiGinxQhGrE4NeAkGE+NtUQTudLznbzS6YPiLYVDQW+ZSEHCCiHZalGAIrvIg\nEqw+sAvKG+/OubE6a4YJPP+Dgt5uSDwJIvK0HMGIUfsYEIzOgnOlJuUn8ist6ykuzqa2MPzQWKVB\nmoO7J9CES4IgwkeLEQxlqXMZlQOzFu9UfODuGPVqnceuDq/RPMc3blFCf112fTGMrYU78N/cj33m\nlhypOI5lx1eFJARFpmL849ensOP8ngtpbqPTWEkIBNGSaTGCEatTgyvPgiP/Ctj+6A8AYNSi1eCv\n89ep9R7vvTtnrpHdQf5mltfnhsktP4Q/Ko/B7LB4fP7e/o+QU7AVxZb601G3nRcF9ZvjKy+gtY0P\nJQAQROSJuGBs2rQJt912G2699VYsXLjQZ/vKlSsxcOBAjB49GqNHj8by5csj0o5YnRoAC2dxR/B1\nBgAAoxI7mYpam+/+3hYG5y0YTT+ira+TtDnFCYoNcZ9J1gnLqOrZs2m5FP4fBBHtqOvf5eLheR4v\nv/wyPv30U6Snp+Oee+7B0KFD0alTJ4/97rzzTjz77LORbApiNIoOz5UtpTaYwPTKwYdbq9Cv9ZUo\nV1nlXXReguEdYG6siXvBqC/obeNEIQxkiYRyDbzLbaViLm1j9FKIKRFEtBNRwThw4AAuu+wyZGdn\nAxCFYcOGDT6C0RhBVYZhMLhnazh5HtsPl0DgVECMGSyAYuzAD9U74CjqCE2WuL9PDMPHJdVIWVJy\nIF68R8p4RL0WhkswuAD72UNY31wKJrOXumBcAgJOENFORHuBkpIStG7dWn6fkZGB0lJfv/n69esx\ncuRIPPHEEyguLo5Yex6+sxsevqOb+MZ7Eh8ANq5Kfu0tGD4xjAvooOqcdTAFmEXu5J04VXMmYNFD\n76Ra7gIEo04SjADHDmU9DU52SV3igkEWBkFEnIhaGKFYDjfffDPuuusuaDQaLF26FE899RQ+++yz\ner+XlhbfoLapoAUPz9iFKsEtGMnxCR7bzJwJKSkG+b2d57D5UDEcTh7jhnUJeq57v5kBAFg27gOf\nbR/vWYp1eTl4/JrxuKnjdT7b1WrRlabRqpGWFo86h3s9D61O3BboXkgWRHxiDNKSxX2U/xN9vLre\n+6g96Tq/StXgex5JEhJFgb+U29jY0L1wQ/ciPERUMDIzM1FUVCS/LykpQXp6usc+iYmJ8ut7770X\nb70VWl2gsrKLXzJ1wfQhmL//CE7XBj6GYPcM8p6qysdHO5a53xdV4cgffwAAbr4qK6Tz+mvzznNi\nrah9547gyviePtudnDjCt9ucKCszwuJwx1lWHlmLUd1uhana11LgBR42pyiI5ZW1iOfEc9tclXoB\noLyqFmWa4Pex1iJaRoLANOieR5qyylp0TmnY76K5wws8BEGAihXFvSXfCyV0L9w0VDgj6mfo2bMn\n8vPzUVhYCLvdjjVr1mDo0KEe+5SVlcmvN2zYgM6dO0eySQAArUaFWI0+6D7eLikA+CU/x/2GcY/U\nrTYnVm05jYIyk8936kNy9YSa5ePtevnm0Gq/+9kVwqAMeludVr/7BMLqWqHwUgx6K914NHEPeGPX\nfEzNiWzyCNGyiaiFoVKp8Nxzz+Hhhx+GIAi455570KlTJ8yfPx89e/bETTfdhCVLlmDjxo1Qq9VI\nTEzEa6+9FskmyfgTBCU6dUzwAzDuzmrN72fx4/az+P1wMV7/+8CAXymrsiAtKdbjM6kjrr/DEwXK\nWzACxSGk+IX3d6yKJWrrW99cuf+lOM9BKRj1TaS0OutgcViRok+KdLOajAJTUf07EUQDiKhgAMAN\nN9yAG264
weOzKVOmyK+nTZuGadOmRboZPsRpDEG3/7b3PBDMCFFYGL/uKwQAlFZZfXZTjuif+nAb\npo/rix4dkuXPWFkwgge9pbN575cS678DVLqeLE4rquqqkaRr5SEY9hCC3lL7bSFYI42NRwJAPSnG\nL/7+JowOE+bdOBtq1v2zLzGXotRajp6p3SPWzsamOZRyIZonl56foZFI0rUKuj2voB6fp8LCsNrc\nnVVRdSU2F/4OO+dAta0GT26apfiOgEOnKzwOU59geOM9klZ2fkqk+AUAfHzoCzy7bTZMDvNFWxj+\nBEMQBNQpjtfYKDPV6nPpGR2iu9A7PfqlHW9hwYFPfWbD+8PqtOLzP75BhbXyIlrbeDjJPUdEiBYr\nGCm65OA7CAycFZngqlMx54aX8UC3cR6bGdZ/B/W/wxux9NhKzN37Ac6bSzw3MoJP9VlVkBhGjckG\nh9Pzc2/XVSBXkY3znb2+pXAHamw18vtQ0mqlOIeDd/iI2rLjqzB90/Mot1b4+2rEUbraQk1zDnS/\nQonnfH9yHXYU78HHh74MrYFNxKXoPiSig4i7pC5Vkr0sDN6mAxujHC0zcJy8CgCgZbXINHhmd0Hl\n+VBe2y0duXnlyCs7DyQA+cYC6FRecRKWh3e1cqnkhj8LY+p7WxHT3Qw2zl1LyltYvEuWSNT5EYzV\np9Z6vA9l4p5ytGrn7B4z4DcVbgMAnKw+g1R9Sr3HCjceMYwQR9UNEQyTy0qxOOu3RpoSEgwiUrRY\nCyNZ5+n7F2zeAQu3H7isxoo4lWc6GqPiPPbpnJ2I7NQ42AWl6HjXU/f1LdfvkmI8DiUJRvuEdgAC\nlzgPJeZQn2BwPOfRrkDHbKpJc6FaGMoONFCZFBtf//2S7r3qEq+rVZ9gVFirkF9b0EitIZqaalsN\nqhWehYbQYgXDO+gt2Dyzl27sm4Ur2ohzRJ7+cDvmfn3U9yAKK6NVXAwMejWgltJQVT6dE8PwqLN7\nfsa6TI5QO13JJRWj0gIAnAHcSsoYRiCCuaQsDisWH/Z0vfhzcwFNt3iRMp4T7P4p4xPK4LiHGDpD\nEAzXdarYS08wlIHu+mJTXx/7H+bt+zDSTWpSjlXm4VD5kaZuxiXBM1tfxTNbXw3LsVqsYDAMg8d6\nPSi/97YwDDo1+nV1u6GKyn3dEIzKgZf+ei1u6puN3p1TEafXgNG4C/6dr/QKnDMCrHWeoz+3heFp\nfQRKE5U6Rq1LMCQLw8bZsf7Mr6hzCUWgzl1JMAtjw7lNyC075PFZiaUMB8oOB2xTY6M8b7BAr1Iw\njledxObC3wF4phjbQ7IwXIJxCVoYSvGrb2GtKlsN6jibj1X2R8UxbMjfFJH2NTbzcxfigwOfNHUz\noo4WKxgAPFIp777OM62ydWoskhM8YxBclVccQ+1Em7Q4jL+lCzRqFgadCtCIHQ/DAHvzvIPePCw2\nz4dZFSCGYbJ4duZyDMMlJLKF4Xrolx9fhVWnfsL3p34C4D+GIZERmwYAcChcTA7OgV/yc2C0i356\nf4Kz4MCn+PDgZzhZfcbj86bymXtM3Ati5Zgd7jpey45/h6XHVsLssHgISSguPOmeMD4Vvpoeh4fb\nLfj/w+K6bm9h+e/+j7Ei7wd50EEQ3rRowVCSluBZO0rNAu3S46BWsejdKQWPDu8OR35X8DYdeKvo\nzrqqS6LHdzQ6p0dQ21jnNS+DEXDwVAXyCmtQWi1uqzWLwuAtGEYvwTBZXYs9ebmkpIf+ZM0ZAJBT\nPoN1gFL8xq7oMH49twUr89bg08Nfi00N0imWeC28FErAOBJ4xDCCpNWa/KTMztj8AnaV7JPfhyIY\nUgFJK+c736bYXIKcgm1NNgdC6WoLZmEIgiALRqC5K6GkGEcLds7hMVcq2gj377HFZkn5Y1BWf2wt\n2gEAiFHHILWVHv+degM0alFX+1xxO5Zu6AxD+3P4teRnDL4q1eP7lapTgKKfL60xAcpYOcNDEIDZ\nS/YgwaDFv/7cB+fKjFAleqbLWm1O5Jd6urOksiOSsGjlGIYTJocZJRaxxIqaVcPisOCcsTDgdcZp\n4qBlNXLHAQCVtmoAwDmT7/c0rMbDL27xesCCWTNKzhmLkBGbKre9oXi4pIJYGJYAHeCPp3+WX4ci\netJcDmU9L4k3d78LG2dHRmwauiZfXu+xwo0zRAvDwTtk951yP2XHYnKYomZGfH3p1s9ufRVmpwX/\nvfnNRmpR46J8RsIhHi3ewojXxgEAErTx+EvXu/HMtdNwW/uh6JHSFQBksQDERZgm3NYV2UliSm4d\n79lRnnEcgsCpwBld2707IUWWlIk9j+NFJfJndo7DkvXHsOdYGWYu3I6PfvAM2Dk5Hk6Od1sYrGRh\nOPDt8VXyflW2Gry+az6OVeUFvOZYjR7JuiRU1rmr80rBd3/ZWglazwyxyjpRXCQrxBrC5L2ztefw\n+q53sPDg5363VxltKKu+sJGeMs7zw+l1KLf4n1Bn5epvX7CYjyAIKLGUyddpddb5PHyShVJsrn/Z\n20igFMxgguGRAKDYz6xIFfZnkYWTs7XnsOb0zxGzxpTHrS8T0NyAFOkySwUWHVyCKtfzcCmiHOiF\nI9bY4i2MGf0m40TVKXRKbA8AyIrLRFZcZtDv6F2FCyVTluM5LD22AkauGoI1EXCIdagY1hWgFmJh\nZyxol2HA2VPA4GvisYdZi2+LDgCMmJ1VXmvGmX2F+HVvYMvgx9/PIqOD+ANgoQEgPvSVio6y0lol\nj4SVZBkyUWQW1xqJVeuRrE9CsaUU5dZKrMxbA6tD7PAkwVC6pBK0caioc5+jwjVRT8Wq4OSdIc32\nliygI5XH/W6f/t+tAIDF/7653mNJeE9iXH74R9zdfqTPfqH45INZGHtKcvHJH1/L7zmBg513yG5B\nQExe4AUelbYqf4eIOMpFsoK5pJTW4cZzm+HknfhL13tQY6uVP1fGfCLBm7vfBQB0T+6CDontwn58\np8e9qH+uESA+wxea/fbl0W9xovoUVAyLh6+874K+21jYudB+F6HS4i2MZF0S+re+2mcGdjCk9b6l\nEefRqjxsO78LACDYdRAE1211CUasVtz/nps64L1/3oD4BHEExGhtYFwlRjiNCWCD/UMFfLflNL76\n5RgAYOfhcgDixD2T3YxkXRK6Jl3uVywA4OqM3vLrU/kWJKpFK2jx4S+RW3YQx6rF43I8L8dLJOK9\nLAyp85dmqQeyMEwOM/aUiOXbA5Uw8b7GMkvgWeMlljIsObJMdjF5xy10AVxddSFZGIEF41DFMZ/P\njHZPl2FyjHg/KwOMNo9UHsdHB5eIAl9XhUkbZ2BL4fZ62xUqTiE0C0Ppnssp2IatRTvh4J0egmEK\nQTA2FfyOY5WBrdhQECDAyTuxMX8TamzhKz+uFIlgqeOemWWe+31z7Du8vef9oOepC1I2JxR4gceB\nssMRTRpRXlc4ztPiBeNi0KtFC0N6+DSKkYmejQMEl/i4BEOvliwOA
bE6NZLjFSm8CjeVpp17rsdl\nGfEY3Ku1z34Wm/jjLCqzQsWo4OCdMNqNiNfGITkmcLmTGJW7+u7BE0YcPyUe52ztOY/9HByHqe9u\n8VjqL8HltpMos1bAZDfL0xKVMYw6uxN1dvGHOW/vh1h8+EscrTwBdZBUVMmFoG57DC9sfyOgFfLR\nwSXYfn431p39FYCviR2r9V8tMpQ5KcqHnuM5HCz/Qy7O6G8sccbrvsVqREuxIkCZlPdyP8K+soPI\nNxZgvys1+etjK+ptV6g4L8LCkKhz1qFaaWEEWB0SEDvg38/vxjfHV2J+7sKQyssEghd4bC3aif/l\n/YBFAVyVF4PSDRXMwlC6IVfk/YBFB5fI7w+UH8bJmjNBxUD6/VucVvzvxOoLDp7nFGzDhwc/w/IT\n/pcoCAeOEO9FqJBgXARS3ENKQVVWfb217+W4tovo0hJngwN6jdhZSx1cWqJSMNyjHFXKeQzskYGJ\no67ErIeuwd1DFGufS8Ii/RVYsGBRW2eEU+CQoI2Dw+w5+VBJnMbd6QucBtZa/6NxhhXAXnbAY/Tl\nbWEAwKmaM3JnoXRJTXx7E6bM2wxe4GUXmNlhCVpcUSreqE4XO+HD5X4mScI98pUsGu+AZqBg+rmK\nwD7m7DhRlJWdx8Zzm7HgwKdYmbcGgKd7Ttr/lCsrTULqpPy5v5QWmCAAWpUmYHsuFs+02sAdg9lP\nwP7fW17C/rKD8vtgFsaPZ37BF0fcC4ntU3zvQrFxdnkGcr4xfDPPleVygsUwlMkLW4t2IrfsIIrN\npR7t8rYk/XGq5gw2ntss/15C5axrtv3hCv+/93CgFHRySTUR0ixx95wF9yhEzaoQqxM7LrVa7CSl\ntTUkF4pe5+6ADLHukTdvTMLfhveQJwwmGrRITxbdWR2yXJ22LBgMnE5GDvSWlfPYsivwjztGUMxs\nd2qgsicG3FedVugRRFXOio9jxJpRO4r3yHNDvF1STk7A3L0L5PecwHn8WJXBdgAwSi4wTrwXoUyi\nA3wD9NJs7RqTDUvWHYOlzoFqkw1nSgPHFeQUY8WDddbVeW0q3IbcskMe7spOiR2gYlTIry2A1WmV\nvyfNafE3Ij1ZfVp+beftcsJCOAk1SypQHaxDik4rmGB4lxTJU1zbhWLjbBdcrTkUPF1SF2Ztvbzj\nLfx6bov8vjaIYAhepX9MfiyzIlMx9pX6F1XJMyH9vywOK7459l2DqiGXmEux/syv8v20k0uq6VGz\nasSq9XK8QBkwbROXJfv2e14udsrSyFcaESs7z1idCq1ixP3aZvq6VGK00r+IR0KsBnBZLe0zWoHn\nGFmEzhU5INTF+XxfQnC4JyEKPAsVFxd0EakCxahco1LLbazKF91eylngUozA5nCLjHIE/sXGQ/hh\n+0n5/XPbPBfJkiYpCrxLMAK4Obw9QzUWT6GycXYsP/49nt3+MjYVbcXKLSfF+SyqwA9Ksi4JDBiY\nHWbwAo8jFZ7usEUHPwerOHOsRo9WMYk4XZuPJzfNwqeuYLjUZn9ip3Rf2TmHd4WxsBCyS8qPhSEh\n/U4tF1Cy/nTNWZ/PVp9ah8m//tuvi0b5v7VxdlkwvDvfhhCqGybQvZAqAQBA7QXEVjR+LMdXd76N\njw75z6RSueJ60v9u7dkN2FS4DYsPfxXyOb15ffd8rDr1E/5wxd3IJXWJEK+N97EwhrW7EV2TL5dn\nb0sTo6RsGs5P/jsncFCzamhYNTR+PBXSSIETeKQk6sGoxXON6H8FILj/fYJDi9v6XBGwvTaz4uBO\nDVgwaBvfJuD+J8vd2Vqck8Hw9L+g7vBAOM939D02Zwcv8O7Z6V7B+zrOiqLKwMvXykF2WTA8O90q\now3Lfs0DL/UprpjH4TOe8QK7045DFUfAs3ZoLzuKU/wu1JrtYLwE4/FeD8mv4zVxSNYlodRajg35\nm/De/o+wr/SAx/5KC0PLajwqHR921SuSihfaOLtPuqgyTnS2tAq/7PXtZC+UIlMxlhxZJrvAQg16\nB0sjzTZkQqvSXpAv/ry5xCdLbu2ZDeAFHqf8iIkyA8vO2SNSZsXOhSgYAa5TmQAQzMLwRs0ETuyo\nqPO1ciWxlP5fVod4H6v87Bsq0rMjXZuHS6oB8SYJEoyLJF5rgNlhAcdz8j/p8iQx5iA9BNLnsoXh\n6vyVI0CO56BiVNCyWr8/bqnGlCAI6N05Ra5VlRrXCjq1wrXhiEHrFIOniHDuhzG/xIwXBjyFbNMQ\nCPZYgAF6p/YIeH2M3v1gf/5THhb+7zQEcyJ8x/kitVYL3l1xANDYoO/3i+ex1A45xVjim40n8NbS\nfdifV+4jGJVmz07tvysPYu2OfNSaxftZXiM+WBVG8aEQHKIYHj5bCqvdAYFTga+LRbHqEEqM1T4W\nhkHjjvXEavRIj02F0W7yqZ2luAL5lValRasYt2BwAg+O5+SHkRd4n7pWysmQq7efxKlid4fgLS5O\njkdFTf0j/PdyP8L287uxqUAsMa8UiR/P/AKjzb9AW4NYGMm6JMSq9bA4rDhaeQKTNs6o1+UkQEB5\nABeKt+sRgFc5FltY1op38k6PtWeUz1GwVSUDueeU1s6FCEawisf+3ExS3Ez6vbBs4LVxLhZPC4Nc\nUk1GvDYeAgSYHBb5Hy9ZEtJD4BYMsUOT6h15+JsFJ1QMC41K43cugKCwMG7vfxlSUljX+eOg07qt\nBsGuQ2ZyLFSu+RlcTQrq9ruXxt19tBSp+mTE2kSrwmx14I+9Bqh5/5lFUsBePLjnz8Rxzncm88b9\nZ5FfYoIqsVz+7NrMvuILlcMjuA8A63adxR9nqjBv+QG5DIrkkjpTKprvuSfK8d6KgzhVVOvxXXOd\nuH+VSez8JKsnv7QKRrsJgtUAriwbYAScNp72sTCUgqFX6eTaWkWm837vhTK4rlVpEKd1f1/sMCs8\nOhnl/5EXeE/fNst5LL7lHfP46ufj+NcH23D6vOc1eyN1ZFKGmrdVsTV/N+ycHW/smo/NivRdY5AM\nqDitAVomBhanFd+d/BEAsO7sxoD7S6nSyjk6Skot5T6f1Sg6YBtnv6B09kB8/sc3eGXHHLnGWUNd\nUkqCxjC8xD5QRQEAKLX63gvJOpS8CNJAM5yCYQ/RVRkqJBgXSbxGypQyyg+9JBisK5gljTiklFZ/\nJRmszjqoWRW0XuU3JKQfDy9w0KhZpCQzYBkWsWo9DFq3hSHYY5CRrIeaEQVDcMQAzhjYT/cA8gai\ntNqKFZtO4dAp8eGuNtmx+0gljLuvB28LHMsQD8bg0eHd8ehwsUCj83xHOMtbe+xSXF3jOq9bxDJi\nxeA9o3YA3isUKiyOo/mukSgv/hwZlsO+E2WY/78D2Hu8zKc5VpuYumu0ivdXcLrOqXaAUfEQnFpw\ntWLZlmO2XWDUTgic+6fOcu77tvC7E+CtogAEyqhR/l9ES9DzwSu2eLZRKRhWZx0ECLK7gmF5D/H0\nniT3W26R2O780GYPS0Ll
LRh6tQ7Hq04i31iApcdWYHdJLowWO44WBp6JrmV0OF/mgNXhDuaXmMvw\n780v+V0/o7Xr/+s9ek5yWWD+Zr1vK9opv7b5qZh7MewpFef6nDWKrj9HiC6pGntgUZYyIWvtRtTY\narH65Np6Kxr4EwzJ7VTmRzyVrjzR0yAlAISv+rPTI+hNLqkmQ6q1U2otlwVDK1sYomBIIyzvGIZ3\nh6NiVNCoNKixG7Hx3GaPbbwsGOJfo92EOI0BLMNCrxCMB4f1RnysFipB7DylUTVX1hYPDL4O+hgV\n1vx+1qeMOsACXPBJdXcO6IABPTKRKqcDM4hTeWZZ5Ve4On2lMNh0EHgWsQbBQyDE/dzvD5wtAiC4\nM8BYDut2es5zUFJSZcX+vAq54xVc7We04ohRx+qRrE4DBAZmeI1+BeDpD/a43zq1yDsWPGvpWKH7\nYeccLK5OFydBSpaJdzFGKY7x+dqjWL5dLHAYKy3AxXIe114TYATr5HhY6hx44ZOd+GW3773wHpl7\n19Kqc9o8XEKfHP4Kp4trRGsvAJxdDTjVAAOYXPG5irpKGB0mfHlEnDOitKQyDRkAgJ/zc7DkyDL5\nNyp1kiavSaS8wONg+R/y82F12MI7ac3121aOqpceW4kKi/+YQHVd4EWFLotvCzWrRq3NhC+OfIu1\nZzdi+Ynvsfrk2oBVAcyKmAgvCHj3fweghvjb8mepKEvWmJ2WiFgYlFZ7idAmLgsAUGAsCuiSktCy\n7iypM7X5+PbEKo/tKpeFAQD/O7Haw+8qPaBSR2+0m+XRj+QSiFFpcX3PtgDcJUOULqWMZD06tPas\nxqtEUAhG2/g2MDCey9cmxYsWkkHv3q+Dqg+cFa3hrBDnnJQbxc5BGatYsek04NRAq3P6rIF+x3Vt\nkJKgAxNjgb7vr8joc0zuSBmVE8fPBR9hf/j9YberidNA4FkwWvEBTI1LRFK8HoLdbTk5Czu7rlUD\nTtkUpxrxbBKuSOoc8FzKTBmjmccVSZ3wxuBZGNnxDgDAqpM/eexv42z442wVck4cwU6H+L82Vrvi\nSSznIarKAKdUYBIA6uwczhYbkV9iwle/nMD5Cv+uJOlzKcGiZ2o3AKJlU+iaByNxrqpCtPYCYLOq\nIHDi78c7OJ5fVo3P1h7FiQL3/0XKyqu1G7H9/G6UWsphcVhlq8nG2cHxPH7dWwCrzYlauxGcwKF9\ngp5oOOIAACAASURBVPhb3ZtXHHQdk/rILzHi87W+cxi8rYrdhQd89gEQdBW6VrpEJGjjUWs3oszl\nTtp+fjfWnt0ou+u8UVoYNSY79p0oh80h/l/8xVKUc3bqnHXyIMDJcXByDRMNwY94kmA0IZJgnDMV\nyqmUsmB41aSJUQS939nru9KZaGG4R7nVNvdDyStcUg7OgTquTnaHSSM8pR/8xvbXAABu63a1/Jle\n67kYlA+u2IGW1eDf10xB70zPGEVGsmhZGHRud9O9N3ZFRu110NSJrh9G5cSgnplIT3Ffh7MqHQKn\nhsBy0Lg+7p7cBQDQ/8pUvPH4QDx0t3gfazVnkJHsmhGvtfmMhNtlxEGjEn+ukoBJGWOJOgPAq8Cw\n4kPSp0MWurZLguBw3ffaJDgrxPPI7isXAqfG+UoLLk8ILBjSGicAUFhqxQ/bzmDOV39g8Xdn/O5+\ntrQam/cXQRXvtm4cVnd9MaV4StkzgiDg+Y/d7poft5/F5+vcJUmKK7zcHa6B/t7jZeAFQR5JsjZR\n7MtrjSg0esZkfij6Fow6SMqthREtDH+onMjJLYLD6bYwtu2rggbuCgKv75qHf22eJcdV6pw2/LQ9\nH0vWH8dHPx3Egv3igkaJajE12+qou6BUT++YwSuf75FdeEq8j2nQxsLisOD1XfOw8ODn2FYklvGp\nCiIYBnUsErTxqLJV+0zGlFxw3u2p42zILTuE/NoCVyKHIKfB+0u3Vrqk6jibey4KI3gMHi4Gd4IN\nzcO4JIjTGtAqJhGFxiJ5wphkSXinCmoVLil/D4hoYbgfVGU9ImVareRzTYgR3RvXZw8AALR2CQcA\n3NbxBjzVbwqGX+Eu4hejVWFI7yz8c6y7nhQA9OjgKiXisjCkH1m7BK90W9dzEatztzG9lR4v/bU/\n+nfJdl2EExNu64qh/cTYRrZpiChEvAoO3o6+XcRzSQFnB28HyzAex9QoPEOsoQa9O6Wg7xVpmP/E\n9XjhoWsRFyvu0L2Dq/S2a7TMcFqPjLAEXRwG9cyUG945qxX6XS7eIz2rR5u0OGRX34K72t6FyzPT\nUVJpwffrA1s0jEIwNueWYMWmUzhbbISxxn/n+vnPf+Do2SpZsADI1k5KksYjhiH9rw+f9g0cl1S5\nXRw1Fs8OR9lXlVVbsf246LbatV+0hn7cfgJnazw7U9bg6xaRKisDgMkICLz/a2Jj6qBuc8xzXXpe\nBdbhntTp/du2cXbkl4jnPG07iHMmsT0qh8s9p+LgVEysyyuowRfrj+G3fYU+9czeWroP73zraSl4\nj8Klxcm800ctDisOVRzFOWMh9pcdwpdHv8X6M7/6WBgjO90uv47V6OUqzd712Q5VHMU3x1bKMTQA\ncvHSRQc/xxu756OwugxgBDByNWpfwVC6pGzO8LrnpGOFGs8JFRKMBtAmLgs1diPK6yqhYdWyZeHj\nkvLKkvJGxag8ROb387vk0YskGIIgyJN/pKBia0MGXrp5Oib2flj+LsMwaJfQBizDYtq9vXHnwMuQ\naNCCYRj06pSCuf8YhFZxYkd2W3+xUqjUmUkxlnZe8zPaxouioFb5/lx6tBU74lv6t4ZaxcqB41v6\ndUC7jDhkJsXDzjnkY0uCIVlFyrIbyh83a6jBwCsz8Y8xPRGn97IKwGPGn/ugXZY4ur3vph5gBPf9\ni9fGISMpFhkpomWk02rwyJ09wTIsumW3xkt/vRYzx/wJt19+Azq71m231/qWP5HvqZc7rX/3DDw3\noR9iNQGSBVgOtRbPQL+0BHDHbIOHu7DCWonCcjPeXrY/4PkBYEXOKZRWiVZGrcXunpMCYMfhEhTX\n1Hich9GbwTMBgvjn28uv7Sf6yq/PFtqgjg2c6aPJOu3hchR4Vs7K80ed04bjLheWco7Clj01oguR\n5WB1uNs4+8td2Li3EJ+vO4YPvjsEQRDwyY9HsGTdMfxxpgoHT1WgpDJw+77fKqYAe5ezt9itOFzu\nWUBylWtlSo+qzGp3XM6gjgXnDNw9bir83cPF0ye9F2IYd8ZhbuU+j1iVOFdJQG5eOTieh51zeIhI\nHWfzOJ7VduGuui0H3BalFOBWpjF7u04vBhKMBtA2XnRzVNZVedQx8rEwWLdLyh8qRgWHYvWz3SW5\n2OvK/JBiF5zAyyZ0kmLiWNe0znJ5C2+u7JiCu4d08giQJsbFYOb4qzH13t7o0T4ZN/fNxg0dPS2P\nrDh3BtT8G19DnNY9ipw0+kpMH3eV/D5BL25jY0SzXfLVJsXG4oWHrkVKnAECBNn8l
gRD2i/QDGVG\nZ8aVHVL8XpeTd6LrZUnQ6DioGRWu6pSJtqnuhz1VJ1ozIy+/BQBwU9vroVVpMfmqRzCm83CPY/W5\nXAxcQ1BhSpfpSFSL91LgGXBVab4n59T489DL0aF1AjKT3fdFa0uFI190t0mxlaREdycpxYlsvN1j\nguahgkJ89pPSDy8AfmY9m6wOvP7lXgDAb/sK3fswYhVjyT0nCQYbJ7q6HIWd4Ch01yRL4LPAlSkG\nBAoXXU0N0DnG/b/1hxQnAgDwKhhNQWZoMwJqXbPxVayiq3HEiGVgWA4HTisyzBQd7JGzVfjw+8PY\nfOA8ck7lQttlF8A68dOOfJQHWTfF5uDkOSh3d74LALD7RAF2FfiuRQ8AYy6/S3596A/Fb9GmwoFz\ngRMvAICHYrKknYWzzu2eO2M+6XE9ds6OX3YXYP7yA/jqtwN4dturHseqc9o8LACr7cKsDUEQsPhH\n9xo6UhJErSN8VYABEowGIcUxAPeoH3BniUhIMYxApcfVrMrHjD5SeQIAwMMdw3BbGIHrQIVCaqIe\nPTuKnfH9t3TB2GuuFdvhEjoNq8bjvR7C9Ksn+cRjru6S7nZlQbRGtCotDpQdhiAIsq9WsqqkYL4U\nRI2VBUOaGe32Dzt5p3xt6rQiLM37xm/7pYfB7LDAoIkFwzAea1Ok6sX29UnviblDXkGPFLEjvyKp\ns89Kch2zEtCvSxr+b+jl6JKdgSSdmBwg2HXgje7rFATAfrInBJsBCQat6z66LYwJfYaDd3XWOh3Q\nKk6L63q5V2S8vpM4SdLBOaBz1RLjbTowGjvyCsWBABNjgf7adVBnncTrfx/gc93VJjtKq61Yt/Oc\nPDKWrB9GbYeKUePqTqIYMCrX78acgPuvHiofQy3o5fRlEQa8yfV74tS4odOVuFn1NzhL2vqcH3DF\nlyR4Vly22BK4JI00adIuKOam1BmgZrSAygmrXXE8ZSYd60Su8ycwhhrEdNkDVWIFVClF2LS/CDMW\nuEt3eMByWPDdIew4kQ8A6JjYAQBwpCwPjNbXJZSh6oCuSe543bY97s71fKkDzoLgKydyikHemUIr\nHA7FhD+hwmP+j5134Gi+6HbcVrJNHvnzZvH3dryoHFa7u42WOv/WYbG5FL8VbAXnyqKTz+e1pPNJ\n1+RQk90ENatGliH4Gj+hQoLRAC5LcD9UbeLd4uHdyUrvt5/f7fc48Zo4H/+lVE5CDnpDkJdRVVoY\n4UCr0uLpa/6J5wf8S/7sytRu6Jh4WQjf1aBnSjeU11Uip3CbbDlIQiFZXmaHBQwYxLpKw0uCoSyN\n7uAdiNfGyYK7p3S/h59Z6iTdxdosMLgKI2YrrCKdokZWfcvBsgyDiaN74pZrxP9lgk48HqOxe8RF\nBEsCuIps/N/N7uB4crwOzhLRrdcpuS2GDxBH8tdfnYSX/tpf7gCnXz0RE/7UGxpWDTvngJQNLdh1\ngNoOyVro3Uf8q2mTh/SkWEwfdxWm3NPLo70vf7oLVpsTLCvei/atXZMI1Q7EaWLx+IjeHgMWwRaL\nHm2y5ffdsjIxqKf7t9o+Mx4JRTfCunsYAAadshPx/+2deVwV57nHfzNzVg5nAQ77JqsiKosKLkQR\nCbihUEEbkza9as1iNKJZDPfT2BtTc29MbZO0ualNW5PWW1vbmn760U+allSjDcFoJGpQEzSKGAHZ\nZD/bvPePOTPMcEBRIQq833/kzHZmXs+8z/u8z/P+nonR/nBcSgSxCyNm4lAjtKNnEah0bcKB2LyQ\nZ30Yhro02M5NBukVNBen37odwv9znCMLcGqg5TTCPlk8R+69qAIvgfO5Bm1Cec9+WYbX1cYOT80B\nlsdn5xtBVN0gThWOfCp00Jyx7/gU59LDSyW0n2A0e674/se14Nv84LjiKYUjIvcwyk8rZWoYlgdr\nUK7zOPWVkH7tsvd4daLBOHSyGqdk3laHree9IITg9FeNuN5uw39/8ir2fvFX7PzgQzzx08NobhOO\n612t8kRVvbBWyd4OL9YAa5fyd3S7UINxB8g77iCvniyk3lNS/ckf6DihY5sTniHN/U+0jke0eQyu\ndtTBxbsU6zDEvHq5NzNYhBlD4Kfvv57GjciLnge9So/3L37gIYciehodjg6oWZW0XXxeeYaXg3dC\nzaolowIA+y/8Q1rcJV+kxhMeXc5ueLmrH4qyLHdKsv9EAO6OThYAJg4tCu6LQk5aT4U4J8/DcSkB\ntmM5MGq8kZuYBL1Kj+MNn0KrgSx7Tuh4NawGNt4OTkVAeAZwaIRaGyoHOJZBcoxyCiw0hMMvq18G\nF9ijydTR7cS89AhpiifIXweDTgVG7YC3xuD2tnqmRhKCQ+Bn6mlPs16PlfenwnI9Gd2fT0eo1YA1\niydImXI+Ri1iQk1ISwiAUSecp2I0eHC2ctoSALzcWQrTxgfC4owCf90ffKcyfXvt0gSEWg1wQmiL\nU+e63OfqBM0xWZxHrhIgBtcZlhfaCnAnORCwxkb856/+7TFxJ8ZXGLUdxKHBB5/UKfbbKtNhvzRO\n+nypxoGNPz2G7opZsJ1JUxwrZdO5biBFL4/nuFRwNgiG2KoR+gLWolyf4xINjCxxQBrocE5FfZc/\nHDqHr6624vi5a1i/cz/+96v/wVN7/iBNW50i74PRt+HVsv/D2x8cx4ef9coWY3gcOXkVzV1taGkB\nPjndf0bYrUANxh2SEiBYbrm30dtA9Bdj2Dj5MZSkFcNP7yv9EERxOwKC6/ZWRfC7svEcfLQWqZO8\nV/D38kOUKQLX7W3SQjRRuVM0EDaXHQQ9noetjykpAgI1q1ZU5/vo6lH84tTbwn638XQSJzocnSAg\nkocxzicWLMNibrjnSPhWmBqUgimBybBfHA9Xqy/CtNGwXxoH+4UJCPBR1htJTwgEwGBFttAJ6VQ6\nTAlMRrujA1931ErZc+J0mU6lQ7ezG3odAxWrgr9RmAoK8ldhTLARKq5noOHiXSit/lBow8gz7pgE\nD4Bgycwo6Tgn78SP1qSB4ZySDL3ObTC81QY8tUxIsxYHGeK6hycz85EcEoOiObEIdD9XsJ/wL8ey\neHTJBPga3OnbZnOfv+HiolQ8/70p8DXpoFa5f/O8crBkNnGIDjH1dK7u/RaDl5AGLXbyYBA/vu9p\nGDUjPA+r6wQXcBnahE+gDvvS80CGF9pIZQfH6wGeAyE9XgPf5Q1X3RjpsxhXInYvmLz0eHZFirTP\n18uIZx5I8UjD7o8N30rF+sw8lKQVY1PaI9CwGqj8lOtgGLcop+gpWfV+SAoRPFaGcyoMEKPtxNa3\nj+Hn+07BbrootIOswBqjckI38d+4pjqDcvu7ioC38F08/lZeBbA8DGoDHsicOKDnuBnUYNwhDycs\nx6bJjyNeNsKVT0k9NXktNJwGP83c5nFugN4qjTASfYVOJ94nBmatMEprsV33kH0eihrIg0GgQRgd\n17QJQntioF/DypMBWMmA/O3Ce/jHpYNS
pyiiYlWSYRCpbDwHF++SYhdO3ikV3BFXW3upvbC78DUU\nxC68o+dgGRb/kbgCrvoIwKHDipgHhU7GqUVyrFVxbEyoGT8vnoU5KT1TPqKnea2rUbagU+u+Rz06\nHZ1w8k7o1RpMiRXiDQ8tiMJTy1MUWTPvnj+gWPWvHV8Ofdr7CJh6AloNJ02e2HkHCCecJxoMsY1N\nssJXa5NXIdFvHLLC7wMABPh4Yd3SSTAZNPDWq7FtzTT853d61u4I1xE6S51KA2+1AepeZXYtBh3G\nBAm/VdFgsHaD4hiby4as1DBpPQJxqcCxjJRhpvfiwTEcfHUWNNtasLZgAnpj1rsNmU89NGMqhe/x\nEbyHmRN65ubjI41YlhMBhgGCzT4AGCllnCVqyVtI9he+g3Tr8eO1M1GYGYONy5IwNsJHkrbZtCwV\n8eEWxaJWAODbTXDWek7VhvlZMDFaeJ9NGiMWx8zzOCY8SI+XH50uxXXWJa9GpNVtiDnl4lbOKFud\n7u4COK5v3S1GY4M1uhaKZAmGl1brR/pZMXPczaeXBwI1GHeImlMj2p2DLSKfkopyxwF6v2yA0rAs\niV2ADSmPYEZImlR7QszRH+sTi0luZdlYS/9zqncTUTdKlFUWn1deXW7FuEJp9AugzxWzalYlZZPN\nDElHSsAkt8hjh7Sa2cm7cNadFDBWtkKbY7lBEbMDgJhQoSO0WnSYkxqKB7LjoNV4SnHrtSrFd4oB\n94auRpmWmNCBe6n0sPMOdDm7oWbVUgfvQDe0Gk6hVdRbIkakjalX1qJ2OaQ4j2ggwtyDEPkUY7Ah\nEI8nrZRUAnoT5OsFL51yND0teCqMam+M840DwzAeK4XlnuDksYLhnheRg3mRWZg/Rgi021x2RAYZ\nERUidPrfnjMO6wsnSW2i0bmgZlUwaUxotbchJd4q1H2RxTbkXqgIq+1GTFq1lBoOACZvFcbHCt/j\nbxA8Ko4I3yMmU0QGGvG9xBVYFLgMq+6bDR+jFgumRSIiUGi7ZMcyeF9YiCBfL7Asg4mRymCxr7cR\n8yZ7LvIUi6SJzAnP8DhmyewIWC16RITo3OfoEC4aDFYZz2Hdiz6FMs3C70vFMSB83122JvQr7Fg3\nQ/qs1fasHwo0WqBTaW8azxsINxYRotwW/cUsfpC+CTaXHS8fex2AMptKzaqkeXipWJHbYIij3s+u\nnUZqwOAErwYbeQxHw6qlTlT+I51gTYCaVSEveh4qG8/ifK8ypwCgYtWSweAYVqon3mpvV3gYVS0X\noGI4D2M9WDz17RR02Zww6NT4Ts7YAZ/nrxeyz75ur8Wl60I2k9o9DSfGZq7bWxFkCJQMhqi51OXq\nCVwa1F6KHPpoc6RUX6Ld0aGQyhc1y/y9BA/oO+OXY1xQNCJ1Y275ueVMD56C6cFTpM8+WsELKIpf\nAgaMwoPJmBiMyEAjwgK8wTKxKHOvphZrS+j1AGzA3ORIcCyH02c10rN4qw0waY3gW3l0ODqx+Tsp\n2PX5WVx2O1ydzi6EegfjSi814a9RiRBrj0fjghPN7sSQaP9ATP3WePyr9Qucb70AG+lC8bIkRAYa\noWZVmJ84BX2xZpEyVrN4egxe6ZEeQ0yAFeGWQMCtWB9likSzrUURNxJ5dNL3sPPUO4izRONccxXO\nd55BCuLgbQDQIigla9zTUHERBrQ4HHAQb7R1OMB4CVO7S2ZG4WKZHo0QastzLAO5/23RmkEIQYej\nE1rZLXAckZIIArwFoxRqUAqG3g7UYHyDiFIeWeH39VtDAAAs7ikp0cNgGAYaTo2pQSn9nnO3iTZH\nSh2KXaHu2jNqFUeV88ZkwaI19Wkw1CwnqXWyDCtTDW2VgoI2lx1NthYEewcNSX1sANCqOWjVt17c\nx1fvCwaMpKAK9AgFymNPk6zjpfUtf/zir4qYBeAp4xBhDEOEMQwHa/6N5u4WmZClQ5LOFo2VmlVh\nSUIOrl0b3Bz8J5JXobajHskBnvPhDMNIo3QAUjXHvV/+FVMCk9HtErwq0auWDyTaHR0wu41Pq70N\nH1w9jMv2nhgFT3iYtSbJYCRZE/FZw+fSuSJO3olmt6Cgj9aMlAh/uOqm4fznFxBhDJNSyW8F+X3O\nH5ONtKAUWPV+aOpuRrfThrzoXOn5ezPROh6vZm5DafWHONdchX9dPoLJAUm40n4VWk4DjuXAEhYc\nw4FwDrAuHmqiAt+tAWtswv88NhV+Zh3iIr3ReFWs1kjAdxjBaLvAqJyIMIbBwTtwpukLhVy70ZtD\nu7t2jtifPJmy5pafvzfUYAwBNxNUWxqXd8P9ooch1hlgh8HMIcdyyI6cjb1fKIUV+3ODk/0n4lhd\nBc40KUuiBhoCEGWKxNnmLxHo5S9Ne8hLXIoSKfJU2nsFNatCiHeQx2gYgJTCCQi1QkRj6uSd+MMX\n7yqOtbnsYMBIMSwtp5WmPeTV29rs7ah3y6sHeCljLINNkCFQGvTcjLG+cbBozWixXcf2Y6+jobtJ\n4VH3HpGbNEKndt3Wio9rPdPP5ZlzsZYoWHQWHKr5N2plhZPsLofkYfi4g/STA5Lga/GGH26gpXYD\nQgxBWBh1P8b6xCHGMkbanhM5Z0Dn916T9crxnwPomc5jGAZmrQkttuvgCQ8dp0VSaDhOtzXBxrYB\nMEqGQPwtJIWNwXV7C6o7LiPMGIKGLiGlV66N5WfRYoyPBcebAbO7P+mrhOytMuQ90Ycffoh58+Yh\nNzcXO3fu9Nhvt9tRXFyMnJwcLF++HF9/7SkmNty4U41/i9YMHafD541CVkR/dRruNWaHzsCCMdko\nilsibZN7GHJ0Ki2eSF7tsT3ZfwJWTXgQK8YuRUboNMnD6KsmcvAAO69vmm+P/Vaf2+WdXoDeCj+9\nLzZP3dDvdWaGpkt/q1gOZnen2tDZk/Pfam/DsboKsAw7JOnWt4tepcP65O8DABrcAx957EXbayBh\n1vZ4GH1fr6fttCot/NwGQV6LxME7PBa3MgyD9LAUKZHkVmEYBgui7lcYi1ul3eGpNCz3IEXD2uHo\nhIpVYVyQkHFZ11mP67ZWnGxQrlL3NRjh6yU8X5h3iCRGKn9HCFzSlJTlNp+9L4bUw+B5Hlu3bsWu\nXbsQEBCAwsJCzJ07FzExPRlFf/rTn2A2m/H+++/jwIED2L59O37yk58M5W0NOao+Aty3AsuwCDT4\nS4v36js9iwjdizAMg4XROYptfB8yF3JK0orh5J041XAGHY5OWN3TKmJnKc6Tn7h2yuPcCX7jPLbd\nC0SbI/Fa5kt45vAPFenWetmUlDg1E24MQV50Lg5f+VgKXmtYNUxaE+aEZUDHafHP6kNI9Bsnqab+\nq+aIx3fGW2I8FozebQINAZgbPgttjnYcrf1UsU9uMNSsWurQ93/1jz6vZVB7warzRUN3E8waEwxq\nYVBWJyvSVN1Wg5r2r8Ey7G0biKFgbsQs1HXW41TDGagYzmMGQjRuLuJCp7NLSlr49ef/1+f1vFR6\
n+GgtONd8HtHmSKl/+MMX+6Rjvmy5AAAesaY7ZUgNxsmTJxEZGYnQUCHtcOHChSgtLVUYjNLSUqxf\nvx4AkJubixdeeGEob+kbId4nBvMis5B8BwFqcdQAAMvi8wfjtu4KvUuQ9kacVpJ3rHLE/P86mdFc\nHp+PWWEz+jz+XoFjObx83w8VUxJsP/XQ542Zi/sjMrH+4HMAgB/P3goGDBiGwZKY+cgKvw9mrUmK\ne8lH4WaNCQ7eccfpxEPFt+IWgSc8jtZ+ijhZhp98hP1f05+FQe2FsT6xONdc5XENHadDRkg67gud\nhsrGL5DoN06q4d3bePKExxhThMdU0N3EpDHi0Un/geu2Vmg5DT68UqaQ6rDIpH4SfOMRY4mCXqVH\nl7NvzayxPnGItUQhK/w+cCwneeF9VQQ0qL0GdSAxpAajrq4OwcE988yBgYE4dUo5Uqyvr0dQkNB4\nHMfBZDKhpaUFFsu9417fKizDIq+PPOxboTBuMbScBkvjFkvu+nAkziIsMhvonG9veqeBzh+Tfc8b\nC5H+XlT59Ir8WLPGBC2nURoZ2WjZr9fiudzIrD7z/e81WIbFj2dthUrWHvLqdOLzPZG8GicbKlHb\nUY+/XXhP2r9ywoOSqsKMEGEhYn/TkfMis5DZR0rrvYD4nDd6Fx4cVwiGYfB40kqcvPY5ciIz8fTh\nHwIQ4ikTrAmI8xEMr5i+31+qNADMCp0+SHcvMKQGo3eBkYEcQwgZtFz64Yy/lx9WTnjwbt/GHWPV\n++GnmdskYcPbYWlcHv785d+wZuJ3pfUow5G04Mm40lGL2f28xFtnPHfD3z7DMJgckITj9Z9hScx8\nzAm7NzvGvui9TiHRbyzeu1gqZRkBgmFJ9p8A+ANTA1PgIi64iKtP48AwDMb5xOFss5BNtSw+H+lB\nqQodseGCWH9mTliG9P8fbY6UtNz+O+N5qFiVlHnWmyiTclGeWWPCjJCpmBqYgkDD7QX7+4MhA+nV\nb5OKigq8/vrr+NWvfgUAUtB7zZqe9K7Vq1dj3bp1SEpKgsvlQkZGBsrK+lGjpFAoFMpdY0gn+iZO\nnIjq6mpcuXIFdrsd+/fvx9y5cxXHzJkzB/v2CcGa9957D9Omeco6UygUCuXuM6QeBiCk1f7oRz8C\nIQSFhYVYs2YNXnvtNUycOBFz5syB3W7H008/jTNnzsBisWDHjh0ICwu7+YUpFAqF8o0y5AaDQqFQ\nKCODeyf3jEKhUCj3NNRgUCgUCmVAUINBoVAolAEx7AzGzbSpRholJSWYMWMG8vJ6BAuvX7+OlStX\nIjc3F6tWrUJbW8/K3xdffBE5OTlYsmQJzpw5czdueUiora3Fd7/7XSxYsAB5eXl45513AIzOtrDb\n7SgqKkJ+fj7y8vLws5/9DABQU1ODZcuWITc3Fxs3boTT6ZSOH2l6bb3heR4FBQV49NFHAYzetsjK\nysLixYuRn5+PwsJCAIP8jpBhhMvlItnZ2aSmpobY7XayePFiUlVVdbdva0j55JNPSGVlJVm0aJG0\n7eWXXyY7d+4khBDyi1/8gmzfvp0QQsjBgwfJ97//fUIIIRUVFaSoqOibv+Ehor6+nlRWVhJCCGlv\nbyc5OTmkqqpqVLYFIYR0dnYSQghxOp2kqKiIVFRUkCeffJIcOHCAEELI888/T37/+98TQgjZvXs3\n2bJlCyGEkP3795MNGzbclXseSn7zm9+QTZs2kUceeYQQQkZtW2RlZZGWlhbFtsF8R4aVhyHXT7us\nDgAACDZJREFUplKr1ZI21UhmypQpMJmUQmqlpaUoKCgAABQUFEhtUFpaivx8QXcqKSkJbW1taGho\n+GZveIjw9/dHQkICAMBgMCAmJgZ1dXWjsi0AQK8X5EXsdjucTicYhkF5eTlyc4WV0wUFBfjnP/8J\nQPl7yc3NHXELY2tra3Ho0CEUFRVJ2z7++ONR2RaEEPC8ssTxYL4jw8pg9KVNVV9ff4MzRiZNTU2w\nWoXaB/7+/mhqEkTp5LpcgNA+dXV1fV5jOFNTU4OzZ88iKSkJjY2No7IteJ5Hfn4+Zs6ciZkzZyI8\nPBwmkwksK7zSQUFB0vP2p9c2Uti2bRueeeYZSVajubkZZrN5VLYFwzBYtWoVli5dir179wLAoL4j\nw6qAEqFLRm5IX+0z0nS5Ojo6sH79epSUlMBgMPT7fCO9LViWxbvvvov29nasXbsW58+f9zhGfN7e\nbUFGkF7bwYMHYbVakZCQgPLycgDC8/V+5tHQFgCwZ88eySisXLkSUVFRg/qODCuDERQUpAhS1dXV\nISBgcMW1hgN+fn5oaGiA1WrFtWvX4OvrC0AYIdTW1krH1dbWjqj2cTqdWL9+PZYsWYLs7GwAo7ct\nRLy9vTF16lR89tlnaG1tBc/zYFlW8bxiWwQGBsLlcqG9vR1ms/kmVx4efPrpp/jggw9w6NAh2Gw2\ndHR0YNu2bWhraxt1bQEIHgQA+Pr6Ijs7GydPnhzUd2RYTUkNRJtqJNJ7JJCVlYW//OUvAIB9+/ZJ\nbTB37ly8+65Q6rOiogImk0lyRUcCJSUliI2NxcMPPyxtG41t0dTUJGW6dHd3o6ysDLGxsUhPT8d7\n7wmy4PK2yMrKGrF6bRs3bsTBgwdRWlqKHTt2ID09Ha+88sqobIuuri50dAjV/To7O3HkyBHEx8cP\n6jsy7KRB+tKmGsls2rQJ5eXlaGlpgdVqxbp165CdnY0nn3wSV69eRUhICF599VUpMP7CCy/g8OHD\n0Ov1eOmll5CYOHzlwOUcP34cDz30EOLj48EwQnGh4uJiTJo0CRs2bBhVbXHu3Dls3rwZPM+D53ks\nWLAAjz32GC5fvoyNGzeitbUVCQkJ2L59O9Rq9ajRazt69Ch+/etf48033xyVbXH58mU88cQTYBgG\nLpcLeXl5WLNmDVpaWgbtHRl2BoNCoVAod4dhNSVFoVAolLsHNRgUCoVCGRDUYFAoFAplQFCDQaFQ\nKJQBQQ0GhUKhUAYENRgUCoVCGRDUYFCGNcuWLUNBQQEWLlyIxMREFBQUoKCgACUlJbd8rdWrVw9I\n7vq5555DRUXF7dzuLVFZWYm///3vQ/49FMpAoeswKCOCK1euoLCw8Ibqo6JUxHBh7969KCsrw44d\nO+72rVAoAIaZlhSFciuUlZVh+/btSE5ORmVlJdauXYumpibs3r1bKqizefNmpKWlAQBmz56NXbt2\nISoqCitWrEBKSgpOnDiB+vp6LFq0CBs2bAAArFixAo8//jgyMjLw9NNPw9vbG+fPn0ddXR1SU1Px\n0ksvARC0eZ555hk0NzcjPDwcLpcLWVlZWL58ueI+GxoasGnTJjQ3NwMAMjIysHr1arzxxhvo7OxE\nQUEB0tPTsXnzZpw4cQI7duxAV1cXAGD9+vWYNWsWqqursWLFCixatAjHjx+H3W7Hli1bkJqa+o20\nNWWUcCfFOiiUe4Wamhoybdo0xbaPPvqIjB8/npw6dUra
Ji8uU1VVRTIzM6XPs2bNIhcuXCCEEPLA\nAw+QTZs2EUIIaW1tJWlpaaSmpkbad/jwYUIIIU899RR56KGHiMPhIDabjcybN4+Ul5cTQgh57LHH\nyC9/+UtCCCGXL18mKSkpZM+ePR73/tZbb5Hnn39e+tza2koIIeSPf/wj2bhxo+Le8/PzSWNjIyGE\nkNraWjJr1izS3t5OLl26RMaOHUv2798vPXtmZiZxOp0Db0QK5SZQD4MyoomOjsaECROkzxcvXsRr\nr72G+vp6cByH+vp6tLS0wGKxeJw7f/58AIDRaERUVBSqq6sRGhrqcdz9998PlUp4lcaPH4/q6mqk\npaWhvLwcL774IgAgLCxM8mR6k5ycjN/97nd45ZVXMHXqVGRkZPR53PHjx1FTU4NVq1ZJgpQcx+Hy\n5cvw8vKCXq/HggULAADTp08Hx3G4ePEiYmJiBtpcFMoNoQaDMqIxGAyKz8XFxdiyZQtmz54Nnucx\nadIk2Gy2Ps/VarXS3yzLwuVy3dJxA62zMHnyZOzbtw8fffQR/vznP+Ott97Cb3/7W4/jCCFITEzE\nrl27PPZVV1d7bON5fkTVeqDcfYZPBJBCuQlkAPkb7e3tkjrpnj17+jUCg0FaWpokK33lyhUcPXq0\nz+Nqamrg7e2NBQsWYPPmzTh9+jQAodaFKGMOAKmpqaiqqsKxY8ekbSdPnpT+7urqwoEDBwAIJUoB\nIDIycnAfijKqoR4GZcQwkNF0SUkJ1qxZg+DgYKSnp8NoNPZ5fu9r9bfvRsf94Ac/wLPPPov9+/cj\nOjoaqampiu8TKSsrwzvvvAOO40AIwdatWwEAM2fOxNtvv438/HxMmzYNmzdvxhtvvIHt27ejra0N\nDocD4eHhePPNNwEAVqsVX375JYqKimC327Fjxw5wHHfTNqFQBgpNq6VQhgibzQa1Wg2WZVFXV4ei\noiLs3r0b4eHhg/5dYpbUkSNHBv3aFIoI9TAolCHiwoULeO6550AIAc/zKC4uHhJjQaF8U1APg0Kh\nUCgDgga9KRQKhTIgqMGgUCgUyoCgBoNCoVAoA4IaDAqFQqEMCGowKBQKhTIgqMGgUCgUyoD4f001\n1ZxdsABYAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f96f1241810\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test_accuracy tf.Tensor(0.99, shape=(), dtype=float32)\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXdgFGXex79TtiabZJNsGoGEBEihhRJAQBQQQaqABQue\n4h3HqYeK7ThFz0NRz8N2nHIoqKe+dyIKJ4KA0gQpofcaSO/JJluSrTPvH7PTtiSBJIIwn792Z6c8\n88zs83t+9SFYlmWhoKCgoKDQAuSVboCCgoKCwq8DRWAoKCgoKLQKRWAoKCgoKLQKRWAoKCgoKLQK\nRWAoKCgoKLQKRWAoKCgoKLQKRWAoKCgoKLQKRWAoXHfs378ft91225VuxjVPv379UFJScqWbodCO\nKAJDQWDUqFHo3bs36uvrZdunTJmCzMxMlJWVAQD+9Kc/ITMzE8eOHRP2KSoqQmZmpvB95syZWLVq\nlfB96dKlGD16NPr374+bb74Z8+bNAwBMnDgR/fv3R//+/ZGdnY0+ffqgX79+6N+/P5YtWxbQxiVL\nluDZZ59t030OHDgQ33///SUd869//Qtvv/028vLycNNNN7Xp+jz+fXStcejQISQnJ1/pZii0I/SV\nboDC1UVycjLWrVuH++67DwBw9uxZOJ1OEAQh7EMQBKKiovDOO+9g+fLlsu3BWL16NdauXYtPP/0U\nycnJqK2txZYtWwAA3333nbDfzJkzcfvtt2P69OltugeWZUO25XLZvn07nn76abjd7nY/99WK1+sF\nRVFXuhkKVxGKhqEgY8qUKVi9erXwffXq1Zg6dWrAflOnTsWZM2ewf//+Fs95/PhxDB8+XJhtxsTE\n4M477wy6b3OVanbs2IGlS5di/fr16NevH26//XYAnKB5++23cc899yAnJwclJSX45ptvMH78ePTv\n3x9jxozBl19+KZzHX0sYNWoUVqxYgcmTJyM3Nxfz5s2Dy+USfrdYLCgsLER2djZmz56NqqoqQQuq\nrq4Gy7JYtmwZxowZgyFDhuDJJ5+ExWIBALhcLjzzzDMYPHgwcnNzceedd6Kurg5vv/02Dhw4gIUL\nF6J///545ZVXgt7z448/juHDhyM3NxczZ87E+fPnhd+cTidef/11jBo1Crm5ubjvvvuEdu/fvx8z\nZsxAbm4uRo4ciTVr1gh9JdVqVq9ejXvvvVf4npmZiS+++AJjx47F2LFjAQCvvvoqbr75ZgwYMADT\np0+XPXOGYbB06VKMGTNG+L2yslI4V3FxsdAPb7zxBkaOHInhw4fjL3/5i9BWs9mMOXPmIDc3F4MH\nD8b9998f8h1QuLIoAkNBRt++fWG323HhwgUwDIMNGzZg8uTJAQO5VqvFnDlz8NZbb7XqnGvWrMHy\n5ctx/PhxMAxzWW278cYbMWfOHIwfPx6HDh0SBkEAWLt2LV555RUcPHgQiYmJiImJwbJly3Dw4EG8\n9tpreO2113Dq1Clhf38tYcOGDVixYgU2b96M06dPy4Tmzp07MWTIEGi1Wnz44YeIi4vDoUOHcPDg\nQZhMJnz66afYsmULvvjiC+zYsQMRERF4+eWXAXADss1mw44dO5CXl4eXX34ZGo0GTz75JAYMGIAF\nCxbg4MGDeOGFF4Le80033YQffvgBu3btQnZ2Np5++mnht9dffx0nT57El19+iby8PDzzzDMgCALl\n5eWYPXs2HnjgAezZswdr1qyRmQv98e+LLVu2YNWqVVi/fj0AoE+fPvj222+xb98+TJo0CU888YQw\n2K9YsQLr16/HRx99hAMHDmDRokXQarUB533zzTdRWFiIb7/9Fps2bUJlZSX++c9/AgA+/vhjJCQk\nYO/evdi1axeefPLJkG1VuLIoAkMhgClTpmDNmjX4+eefkZaWhri4uKD73XXXXSgvL8eOHTuaPd/k\nyZOxYMEC/Pzzz5g5cyaGDh0a1D/RFqZOnYr09HSQJAmapnHTTTcJGs3AgQMxbNiwZrWhBx54ALGx\nsYiIiMDIkSNlwmXbtm3N+i1WrlyJJ554AnFxcVCpVHj00UexceNGMAwDmqZRX1+PixcvgiAIZGdn\nIywsrNX3NW3aNOh0OuG8p0+fhs1mA8uy+Oabb/DCCy/AZDKBIAjk5ORApVJh7dq1GDZsGMaPHw+K\nohAZGdmswPDn97//PQwGA9RqNQBg0qRJiIiIAEmSePDBB+FyuXDx4kUAwKpVq/Dkk08iJSUFAJCR\nkYHIyEgAcm1x1apVmD9/PgwGA/R6PWbPni2YI2maRnV1NUpKSkBRFAYMGNDqtir8sig+DIUAJk+e\njPvvvx8lJSWYMmVKyP3UajUeeeQRvPvuu1i8eHGz55w4cSImTpwIr9eLH3/8EU899RR69uyJYcOG\ntUubExISZN+3b9+O999/HwUFBWAYBg6HAxkZGSGPj4mJET7rdDpUV1cD4Aa9Xbt2Yf78+SGPLSsr\nw2OPPQaSJIVjaJpGTU0NpkyZgoqKCsybNw9WqxWTJk3CvHnzWuUbYBgGb731FjZu3Aiz2QyCIEAQ\nBMxmM1wuF1wuFzp37hxwXHl5edDtrcW/L1esWIFVq1YJfWK322E2mwEAFRUVLV6rrq4OTU1NMt8U\nwzCCQHn44YexZMkSzJo1CwRB4M4778Ts2bMvu/0KHYeiYSgEkJSUhE6dOuGnn37Crbfe2uy+06ZN\ng9VqxQ8//NCqc1MUhbFjxyIjIwPnzp1rj+YCkJs/XC4XHn/8cfz2t7/F7t27sW/fPowYMaJZ/0go\njh07huTkZBiNxoDr8CQmJuLDDz9EXl4e8vLysG/fPhw+fBhxcXGgaRqPPvoo1q1bh//+97/Ytm2b\nYEpryXm+du1abN26FZ9++in279+PLVu2CPdgNBqh0WhQVFQUtD3BtgOAXq+Hw+EQvvNCQIq0Xfv3\n78dHH32E9957D/v27cO+ffsQHh4utCMhISHktXiMRiN0Oh2+++47oY/279+PAwcOAADCwsLw3HPP\n4ccff8TSpUvxySefYM+ePc2eU+HKoAgMhaAsWrQIn376qWCPDgVFUXjsscfw4Ycfhtxn9erV2L59\nO+x2O1iWxfbt25Gfn48+ffpccrtiY2NRWlra7ODvdrvhdrthNBpBkiS2b9+On3/++ZKvBXDmqBEj\nRgjfY2JiUF9fD5vNJmy7++678dZbbwlhx3V1ddi8eTMAYO/evTh79iwYhoFerwdN04J2ERsbKziF\ng2G326FWqxEREYHGxkYsXrxYGMwJgsC0adPw+uuvo6qqCgzD4PDhw3C73Zg0aRJ2796NDRs2wOv1\nor6+HqdPnwbAOaI3bdoEh8OBwsJCfP31183ev91uB03TiIqKgsvlwpIlS2C324Xf77zzTrz77rso\nLCwEAJw5cwYNDQ2yc/Baw6JFi1BXVwcAqKysxM6dO4U+5oWOXq8HRVFKdNZVSocKjD//+c8YOnQo\nJk2aFHKfV155BbfeeiumTJkisxsr/
PJIZ5adO3dGz549g/7mz8SJExEXFxcQessTHh6OpUuXCtE8\nixcvxl/+8hf0798/5PVDMW7cOLAsi8GDB2PatGlBjwsLC8Pzzz+Pxx9/HIMGDcL69esxevTokOds\n7rrbt2+X+S/S0tIwYcIEjB49GoMGDUJ1dTV+85vfYPTo0Zg1axYGDBiAGTNm4OjRowCAmpoazJ07\nFwMGDMDEiRMxePBgTJ48GQDnN9mwYQMGDx6MV199NeDat99+OxITEzFixAhMnDgR/fr1k/3+3HPP\noUePHrjjjjswePBgLF68GCzLIjExEcuWLcOKFSswaNAgTJ06VRAYDz74IFQqFYYNG4b58+cH/Df9\n++LGG2/EjTfeiLFjx2L06NHQ6XQyk9VDDz2E2267Tbj3F154QdBgpOd6+umnkZKSgrvuugsDBw7E\nrFmzUFBQAAAoKCjAgw8+iH79+uGee+7Bfffdh9zc3JDPROHKQXTkinv79+9HWFgYnn32Waxduzbg\n9+3bt+OLL77AsmXLcOTIEbz66qtYuXJlRzVHQeGSqK2txe23396iU19B4XqhQzWMgQMHIiIiIuTv\nmzdvFmLp+/btC6vVipqamo5skoJCq7Farc06uxUUrjeuaJRUVVWVTL2Nj49HZWUlYmNjr2CrFBQ4\nUlNTkZqaeqWboaBw1XBFnd7BrGHXS9kFBQUFhV8bV1TDiI+PR0VFhfC9oqIiZJKYlI6oFdSefHpo\nFdad3QwtrcG/p79zWee468s/CJ+jdVFYOvm19mpeh1FurcLj618CAKy8+wMAwHu7V2Bn0T7EhcVg\nycTg5S8W/7wMe0sOAQA+mvI3RGgNwm9bLvyMpfs+DzjmnxNfgSksJmD71cxft76D41Vn0D2mK169\nhSugWGGtwlxfn/nzybS3oFfphO8sy+LulY8I33vG9cBLI6/erOjT1efx4hZ5fg7/Xizc9g6OVZ5B\nt+hULBrzHABgT/FBvLXrQ2hoDT6+/e+4d9UfAQDzhv4OQzrLAyQ2nNuGFQe/RDAitRH4cMob7X07\nbeJ8bQH+/CPXpi/vel82fj2z8VUU1otVfVfe/YHw/3/hprnok5CFnYX78N6eFQCA/9y5BBQpRpF9\neWwtvj65Puh1aZLGx1MXQ0Or2+U+OlxgNOdTHz16NL744guMHz8ehw8fRkRERKvMUQRBoLra2p7N\nbFdcDi8AwMN4L6udDU6L7HtaRGrI85hMhqumL87WivH4fJtcLq4v3J7gffG//O8FYQEAP58/jAh1\nOPIqDuGejGm4UFUa9FrPblyEV4Y+DzWlEra1R194GS/+e+YbDIjPQWZ090s+nmVZfHZqJeocZmhp\nLYYkDEBOXG8AgMvtey/cDKqrrWhwWvDS7tAD2+5zR0CRFI5UH8eMjGmosFfJfj9RdRZPr38FRk0U\nHup5L1Tt3Bdt5Xx5YGlzvk1a6Ll96grww8ndyDH1wt4CLrLM6XHi6e/FqLG3dn2Il4Y8ixJbGc6Y\nz+PuHrejrC60r7PBYUFpRS3UlBpWlw3fFHyLMUmjkBSeEPKYULi8bnx+aiVGdxmBlIjLS4Yss1Xg\nrYMfCN/nb3wD47uOQVZ0DwCAjhAnBUZNlOy5LdnzKSLVEahsFJ/9/gsnUeeoR4GlCNO7TUJhbVnI\na3sYD/acP4rsmAysPLsGjw6beVn3wNOhJqmnnnoKM2bMwMWLF3HzzTfj66+/xn//+1+hEBxfvmHM\nmDF48cUX8dJLwWdavzYogutWL+O9rONP18kT2rSUps1t+iWoaaoN2Eb6XjEWgRMHs6Memwq3Bpzj\n3UPLsLt8H87W56O2iYvbD6P1sv3s7kYUWppPGLscTtSexq7yffjH4dB5Jc3R4LJgb8UBnKu/gGM1\nJ/Hh8c9C7ru1eCfcjDvk7yW2Mrx/ZAV+LstDsbUUp+vOBuxTZC3FkZoTOGM+H+QMV5bqIO8Dj0Ed\nLnz+8Ni/Acjf+zJ7hWz/U3Vnsfz459hZugd1DjOsruDCMCEsHgBQ43tvVp9fh7ySw/jPmW8u6x4O\nVh3Bgaoj+Nv+f1zW8QDwwdGP0eRpEr5faCjEksMfCd/rnGbhc7TWCDfjEb43uhtRaC2Gw+sUtlXa\nq/Hxif/D1uKdsLhsqG2qAwECNCmf/9/ShQsHr2qsgcPjxPaSXZd9DzwdKjAWL16MnTt34vjx49i2\nbRumT5+OGTNm4O677xb2efHFF/HDDz/g22+/lcX9/5ohfeoiCxb1Ti6JqcRaJggQs6MeDU75C+/y\nulBm4/4kp83cH+d3vbjZgJdl4GE8KLaGnklcaRiWwa7yfQHbSV54sty9syyLQksxGJbB2gsbA/b3\nSIRsibUMlY3VoEkaTw98LGDfat+g0J6YnQ0t79QMlhADGQCwLF90kROeUnOTFL7PaiT3d67+Ak76\nBMYfc34XcMy6i5uEgabOYYa5qW33cbkwLIMiSwk8jAe7g7wP/D35TyCqG2tR66hDplGu1fGDoFSA\n1DTVweLiEief6Pd7YfsNibkYkjDAtw8nrEps/H+mddkDdRYHzFZxcFZToinnWM1JVNirUGGvwvGa\nUzhWcxIuryvYaQBwE8Yia0nISUFtkxkMy6CuyYyUiM5Qkyq4GTfOmy8AAAbG5+Dtm1+FilTJjiuQ\nTJRqHbWobqpFrC4a83OfELaPSx2N7lFpALixxeqyoT1Qakl1ABQh2hef//lVPNL3Ybx/ZDluSMzF\n/Vl34oVdiwAA/xz1N2G/5ce/wPHaU3h24B9xvv4iwlVh6GzoBIAbbL8v2IwNBZvx/KB5l6VadzTH\na06h2MqZj6QaEW+q5U2TJ2pP44OjHyNKEykIUym8YAGANfmcXTZBHyczPWkpDRxeJ2qbmcFeLv7m\nwEuluT9mo2+Wyc8WvSGq9kZrolDrMMs0ttXn1wHgZtAmnei7yTB2wxnzeRRZS/Fz6V7c3HkY3j20\nDGEaHZ7tP7dN93IplFTb8PW2fPTKteObC2sQ7e6GelXg831p9+v4x8jX4fKKg2iMNhr5DVwxwz6m\nnrjQUACXb5A1aiJR3VSLY9UnhP2rm2phddlAExQiNWLYvppSIUYXDQAotVRj87bDqDJxpiuPn7Zf\nU98EmiYRFS7X3p9+n5uFr/jTKADA7hOioFp69JOA+xmfegsmpAUvn/Ppyf/iQNWRoL8BwIu7X8Or\nw56Hh/UiRmtEbVMdah11WHKE0z6KS904qKqGs5EGqRX7a2fZXuFzibUcNrcdnQ2doKXFe1GTavyY\nVw5oASfjgtXdPuZJpTRIB0CT8rIGxVbOlhtsxsVzvJbLcr/QUAizox6JYfGCY4thGRyrOQkAqG66\nOvNUpIObRjIrc/pmYIxvdl3la38wYQFwNld/0iJTZbMs3uzQnMnjcuEFRqSac7x7GQb/XH0Mu49X\nNHeYgKUZgcELE6uLqzZ74Dw3+52UNhYP97offWI5DZsiKURro4Ka+LKM3WW+Cj2tw9RuEwBwfe
[... remainder of base64-encoded PNG omitted: matplotlib figure of the MNIST train/test losses and accuracies produced by the benchmark cell below ...]\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f96f7389490\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#@test {\"timeout\": 90}\n",
+ "with context.eager_mode():\n",
+ " durations = []\n",
+ " for t in range(burn_ins + trials):\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=max_steps,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 500)\n",
+ " test_ds = setup_mnist_data(False, hp, 100)\n",
+ " ds = tf.data.Dataset.zip((train_ds, test_ds))\n",
+ " start = time.time()\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = train(ds, hp)\n",
+ " if t \u003c burn_ins:\n",
+ " continue\n",
+ " train_losses[-1].numpy()\n",
+ " test_losses[-1].numpy()\n",
+ " train_accuracies[-1].numpy()\n",
+ " test_accuracies[-1].numpy()\n",
+ " duration = time.time() - start\n",
+ " durations.append(duration)\n",
+ " print('Duration:', duration)\n",
+ "\n",
+ "\n",
+ " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " print('test_accuracy', test_accuracies[-1])\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "name": "Autograph vs. Eager MNIST benchmark",
+ "provenance": [
+ {
+ "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
+ "timestamp": 1530297010607
+ },
+ {
+ "file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9",
+ "timestamp": 1530289467317
+ },
+ {
+ "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
+ "timestamp": 1522272821237
+ },
+ {
+ "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
+ "timestamp": 1522238054357
+ },
+ {
+ "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
+ "timestamp": 1521743157199
+ },
+ {
+ "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
+ "timestamp": 1520522344607
+ }
+ ],
+ "version": "0.3.2",
+ "views": {}
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb
new file mode 100644
index 0000000000..e8f16b431d
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb
@@ -0,0 +1,1093 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "qWUV0FYjDSKj"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "from tensorflow.contrib import autograph\n",
+ "\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "kGXS3UWBBNoc"
+ },
+ "source": [
+ "# 1. AutoGraph writes graph code for you\n",
+ "\n",
+ "[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large chunk of the Python language, which is growing. [Please see this document for what we currently support, and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n",
+ "\n",
+ "Here's a quick example of how it works:\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "aA3gOodCBkOw"
+ },
+ "outputs": [],
+ "source": [
+ "# Autograph can convert functions like this...\n",
+ "def g(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " else:\n",
+ " x = 0.0\n",
+ " return x\n",
+ "\n",
+ "# ...into graph-building functions like this:\n",
+ "def tf_g(x):\n",
+ " with tf.name_scope('g'):\n",
+ " \n",
+ " def if_true():\n",
+ " with tf.name_scope('if_true'):\n",
+ " x_1, = x,\n",
+ " x_1 = x_1 * x_1\n",
+ " return x_1,\n",
+ "\n",
+ " def if_false():\n",
+ " with tf.name_scope('if_false'):\n",
+ " x_1, = x,\n",
+ " x_1 = 0.0\n",
+ " return x_1,\n",
+ "\n",
+ " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n",
+ " return x\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "I1RtBvoKBxq5"
+ },
+ "outputs": [],
+ "source": [
+ "# You can run your plain-Python code in graph mode,\n",
+ "# and get the same results out, but with all the benfits of graphs:\n",
+ "print('Original value: %2.2f' % g(9.0))\n",
+ "\n",
+ "# Generate a graph-version of g and call it:\n",
+ "tf_g = autograph.to_graph(g)\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " # The result works like a regular op: takes tensors in, returns tensors.\n",
+ " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
+ " g_ops = tf_g(tf.constant(9.0))\n",
+ " with tf.Session() as sess:\n",
+ " print('Autograph value: %2.2f\\n' % sess.run(g_ops))\n",
+ " \n",
+ " \n",
+ "# You can view, debug and tweak the generated code:\n",
+ "print(autograph.to_code(g))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "m-jWmsCmByyw"
+ },
+ "source": [
+ "#### Automatically converting complex control flow\n",
+ "\n",
+ "AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n",
+ "AutoGraph will automatically convert most Python control flow statements into their correct graph equivalent. \n",
+ " \n",
+ "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "toxKBOXbB1ro"
+ },
+ "outputs": [],
+ "source": [
+ "# Continue in a loop\n",
+ "def f(l):\n",
+ " s = 0\n",
+ " for c in l:\n",
+ " if c % 2 \u003e 0:\n",
+ " continue\n",
+ " s += c\n",
+ " return s\n",
+ "\n",
+ "print('Original value: %d' % f([10,12,15,20]))\n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n",
+ " \n",
+ "print(autograph.to_code(f))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "FUJJ-WTdCGeq"
+ },
+ "source": [
+ "Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well! \n",
+ " \n",
+ "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. "
+ ]
+ },
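+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "breakVariantSketch"
+ },
+ "outputs": [],
+ "source": [
+ "# A minimal sketch of the `break` variant suggested above; the cell id\n",
+ "# above is a placeholder rather than an original Colab id.\n",
+ "def f(l):\n",
+ " s = 0\n",
+ " for c in l:\n",
+ " if c % 2 \u003e 0:\n",
+ " break\n",
+ " s += c\n",
+ " return s\n",
+ "\n",
+ "print('Original value: %d' % f([10,12,15,20]))\n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())"
+ ]
+ },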
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "IAOgh62zCPZ4"
+ },
+ "outputs": [],
+ "source": [
+ "def f(x):\n",
+ " assert x != 0, 'Do not pass zero!'\n",
+ " return x * x\n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " try:\n",
+ " print(tf_f(tf.constant(0)).eval())\n",
+ " except tf.errors.InvalidArgumentError as e:\n",
+ " print('Got error message:\\n%s' % e.message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "KRu8iIPBCQr5"
+ },
+ "source": [
+ "You can also use plain Python `print` functions in in-graph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ySTsuxnqCTQi"
+ },
+ "outputs": [],
+ "source": [
+ "def f(n):\n",
+ " if n \u003e= 0:\n",
+ " while n \u003c 5:\n",
+ " n += 1\n",
+ " print(n)\n",
+ " return n\n",
+ " \n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default():\n",
+ " with tf.Session():\n",
+ " tf_f(tf.constant(0)).eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "NqF0GT-VCVFh"
+ },
+ "source": [
+ "Appending to lists in loops also works (we create a `TensorArray` for you behind the scenes)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ABX070KwCczR"
+ },
+ "outputs": [],
+ "source": [
+ "def f(n):\n",
+ " z = []\n",
+ " # We ask you to tell us the element dtype of the list\n",
+ " z = autograph.utils.set_element_type(z, tf.int32)\n",
+ " for i in range(n):\n",
+ " z.append(i)\n",
+ " # when you're done with the list, stack it\n",
+ " # (this is just like np.stack)\n",
+ " return autograph.stack(z) \n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " print(tf_f(tf.constant(3)).eval())\n",
+ "\n",
+ "print('\\n\\n'+autograph.to_code(f))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "iu5IF7n2Df7C"
+ },
+ "outputs": [],
+ "source": [
+ "def fizzbuzz(num):\n",
+ " if num % 3 == 0 and num % 5 == 0:\n",
+ " print('FizzBuzz')\n",
+ " elif num % 3 == 0:\n",
+ " print('Fizz')\n",
+ " elif num % 5 == 0:\n",
+ " print('Buzz')\n",
+ " else:\n",
+ " print(num)\n",
+ " return num"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "EExAjWuwDPpR"
+ },
+ "outputs": [],
+ "source": [
+ "tf_g = autograph.to_graph(fizzbuzz)\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " # The result works like a regular op: takes tensors in, returns tensors.\n",
+ " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
+ " g_ops = tf_g(tf.constant(15))\n",
+ " with tf.Session() as sess:\n",
+ " sess.run(g_ops) \n",
+ " \n",
+ "# You can view, debug and tweak the generated code:\n",
+ "print('\\n')\n",
+ "print(autograph.to_code(fizzbuzz))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "SzpKGzVpBkph"
+ },
+ "source": [
+ "# De-graphify Exercises\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "8k23dxcSmmXq"
+ },
+ "source": [
+ "#### Easy print statements"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "dE1Vsmp-mlpK"
+ },
+ "outputs": [],
+ "source": [
+ "# See what happens when you turn AutoGraph off.\n",
+ "# Do you see the type or the value of x when you print it?\n",
+ "\n",
+ "# @autograph.convert()\n",
+ "def square_log(x):\n",
+ " x = x * x\n",
+ " print('Squared value of x =', x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_log(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "_R-Q7BbxmkBF"
+ },
+ "source": [
+ "#### Now some exercises. Convert the TensorFlow code into AutoGraph'd Python code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "SwA11tO-yCvg"
+ },
+ "outputs": [],
+ "source": [
+ "def square_if_positive(x):\n",
+ " x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n",
+ " return x\n",
+ "\n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "GPmx4CNhyPI_"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_if_positive(x):\n",
+ " ... # \u003c\u003c\u003c fill it in!\n",
+ " \n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qqsjik-QyA9R"
+ },
+ "source": [
+ "#### Uncollapse to see answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "DaSmaWUEvMRv"
+ },
+ "outputs": [],
+ "source": [
+ "# Simple cond\n",
+ "@autograph.convert()\n",
+ "def square_if_positive(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qj7am2I_xvTJ"
+ },
+ "source": [
+ "#### Nested If statement"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "4yyNOf-Twr6s"
+ },
+ "outputs": [],
+ "source": [
+ "def nearest_odd_square(x):\n",
+ "\n",
+ " def if_positive():\n",
+ " x1 = x * x\n",
+ " x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n",
+ " return x1,\n",
+ "\n",
+ " x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "hqmh5b2VyU9w"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def nearest_odd_square(x):\n",
+ " ... # \u003c\u003c\u003c fill it in!\n",
+ " \n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "b9AXIkNLxp6J"
+ },
+ "source": [
+ "#### Uncollapse to reveal answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "8RlCVEpNxD91"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def nearest_odd_square(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " if x % 2 == 0:\n",
+ " x = x + 1\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "jXAxjeBr1qWK"
+ },
+ "source": [
+ "#### Convert a while loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "kWkv7anlxoee"
+ },
+ "outputs": [],
+ "source": [
+ "# Convert a while loop\n",
+ "def square_until_stop(x, y):\n",
+ " x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n",
+ " return x\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "zVUsc1eA1u2K"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_until_stop(x, y):\n",
+ " ... # fill it in!\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L2psuzPI02S9"
+ },
+ "source": [
+ "#### Uncollapse for the answer\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ucmZyQVL03bF"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_until_stop(x, y):\n",
+ " while x \u003c y:\n",
+ " x = x * x\n",
+ " return x\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "FXB0Zbwl13PY"
+ },
+ "source": [
+ "#### Nested loop and conditional"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "clGymxdf15Ig"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " current_sum = 0.0\n",
+ " idx = 0\n",
+ " \n",
+ " for i in range(len(x)):\n",
+ " idx = i\n",
+ " if current_sum \u003e= threshold:\n",
+ " break\n",
+ " current_sum += x[i]\n",
+ " return idx\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "i7PF-uId9lp5"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " ...\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "weKFXAb615Vp"
+ },
+ "source": [
+ "#### Uncollapse to see answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "1sjaFcL717Ig"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " current_sum = 0.0\n",
+ " idx = 0\n",
+ " for i in range(len(x)):\n",
+ " idx = i\n",
+ " if current_sum \u003e= threshold:\n",
+ " break\n",
+ " current_sum += x[i]\n",
+ " return idx\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "4LfnJjm0Bm0B"
+ },
+ "source": [
+ "# 3. Training MNIST in-graph\n",
+ "\n",
+ "Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well! \n",
+ "\n",
+ "Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Em5dzSUOtLRP"
+ },
+ "source": [
+ "#### Download data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "xqoxumv0ssQW"
+ },
+ "outputs": [],
+ "source": [
+ "import gzip\n",
+ "import os\n",
+ "import shutil\n",
+ "\n",
+ "from six.moves import urllib\n",
+ "\n",
+ "\n",
+ "def download(directory, filename):\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " zipped_filepath = filepath + '.gz'\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [784])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8)\n",
+ " label = tf.reshape(label, [])\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def mnist_train(directory):\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte')\n",
+ "\n",
+ "def mnist_test(directory):\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "znmy4l8ntMvW"
+ },
+ "source": [
+ "#### Define the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Pe-erWQdBoC5"
+ },
+ "outputs": [],
+ "source": [
+ "def mlp_model(input_shape):\n",
+ " model = tf.keras.Sequential((\n",
+ " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
+ " tf.keras.layers.Dense(100, activation='relu'),\n",
+ " tf.keras.layers.Dense(10, activation='softmax')))\n",
+ " model.build()\n",
+ " return model\n",
+ "\n",
+ "\n",
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n",
+ "\n",
+ "\n",
+ "def fit(m, x, y, opt):\n",
+ " l, accuracy = predict(m, x, y)\n",
+ " opt.minimize(l)\n",
+ " return l, accuracy\n",
+ "\n",
+ "\n",
+ "def setup_mnist_data(is_training, hp, batch_size):\n",
+ " if is_training:\n",
+ " ds = mnist_train('/tmp/autograph_mnist_data')\n",
+ " ds = ds.shuffle(batch_size * 10)\n",
+ " else:\n",
+ " ds = mnist_test('/tmp/autograph_mnist_data')\n",
+ " ds = ds.repeat()\n",
+ " ds = ds.batch(batch_size)\n",
+ " return ds\n",
+ "\n",
+ "\n",
+ "def get_next_batch(ds):\n",
+ " itr = ds.make_one_shot_iterator()\n",
+ " image, label = itr.get_next()\n",
+ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
+ " y = tf.one_hot(tf.squeeze(label), 10)\n",
+ " return x, y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "oeYV6mKnJGMr"
+ },
+ "source": [
+ "#### Define the training loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "3xtg_MMhJETd"
+ },
+ "outputs": [],
+ "source": [
+ "def train(train_ds, test_ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " \n",
+ " # We'd like to save our losses to a list. In order for AutoGraph\n",
+ " # to convert these lists into their graph equivalent,\n",
+ " # we need to specify the element type of the lists.\n",
+ " train_losses = []\n",
+ " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n",
+ " test_losses = []\n",
+ " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n",
+ " train_accuracies = []\n",
+ " train_accuracies = autograph.utils.set_element_type(train_accuracies, tf.float32)\n",
+ " test_accuracies = []\n",
+ " test_accuracies = autograph.utils.set_element_type(test_accuracies, tf.float32)\n",
+ " \n",
+ " # This entire training loop will be run in-graph.\n",
+ " i = tf.constant(0)\n",
+ " while i \u003c hp.max_steps:\n",
+ " train_x, train_y = get_next_batch(train_ds)\n",
+ " test_x, test_y = get_next_batch(test_ds)\n",
+ " # add get next\n",
+ " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ " if i % (hp.max_steps // 10) == 0:\n",
+ " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n",
+ " step_test_loss, 'train accuracy:', step_train_accuracy,\n",
+ " 'test accuracy:', step_test_accuracy)\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " \n",
+ " # We've recorded our loss values and accuracies \n",
+ " # to a list in a graph with AutoGraph's help.\n",
+ " # In order to return the values as a Tensor, \n",
+ " # we need to stack them before returning them.\n",
+ " return (autograph.stack(train_losses), autograph.stack(test_losses), autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "HYh6MSZyJOag"
+ },
+ "outputs": [],
+ "source": [
+ "with tf.Graph().as_default():\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=500,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 50)\n",
+ " test_ds = setup_mnist_data(False, hp, 1000)\n",
+ " tf_train = autograph.to_graph(train)\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = tf_train(train_ds, test_ds, hp)\n",
+ "\n",
+ " with tf.Session() as sess:\n",
+ " sess.run(tf.global_variables_initializer())\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies])\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [
+ "qqsjik-QyA9R",
+ "b9AXIkNLxp6J",
+ "L2psuzPI02S9",
+ "weKFXAb615Vp",
+ "Em5dzSUOtLRP"
+ ],
+ "default_view": {},
+ "name": "AutoGraph Workshop.ipynb",
+ "provenance": [
+ {
+ "file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP",
+ "timestamp": 1530563781803
+ }
+ ],
+ "version": "0.3.2",
+ "views": {}
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
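The notebook's training loop leans on three AutoGraph list primitives used above: `set_element_type` to hint the staged element dtype, `append` inside a converted loop, and `stack` to materialize a Tensor. A minimal sketch of that pattern, assuming a TF 1.x runtime where `tensorflow.contrib.autograph` is available:

    import tensorflow as tf
    from tensorflow.contrib import autograph

    def collect_squares(n):
      # A plain Python list; the element-type hint lets AutoGraph stage it
      # as a graph-friendly list.
      vals = []
      vals = autograph.utils.set_element_type(vals, tf.int32)
      i = tf.constant(0)
      while i < n:  # converted to a graph loop by AutoGraph
        vals.append(i * i)
        i += 1
      # stack() turns the staged list into a single Tensor for returning.
      return autograph.stack(vals)

    tf_collect = autograph.to_graph(collect_squares)
    with tf.Graph().as_default():
      squares = tf_collect(tf.constant(5))
      with tf.Session() as sess:
        print(sess.run(squares))  # [ 0  1  4  9 16]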
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index c7401c7df1..f7fe3de5da 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
Returns:
A decorator that wraps the original function.
"""
+
def decorator(f):
"""Decorator implementation."""
@@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
@wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
- raise NotImplementedError(
- 'RunMode.PY_FUNC does not yet support kwargs')
+ raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
# TODO(mdan): Add support for kwargs.
return py_func.wrap_py_func(
f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
@@ -231,7 +231,10 @@ def to_graph(e,
Returns:
A function with a signature identical to `e`, but which, when executed,
- creates TF a graph that has the same functionality as the original entity.
+ creates a TF graph that has the same functionality as the original entity.
+ Raises:
+ ValueError: If the converted function defines or refers to symbol names that
+ are reserved for AutoGraph.
"""
program_ctx = converter.ProgramContext(
recursive=recursive,
@@ -256,6 +259,19 @@ def to_graph(e,
compiled_node.__dict__[key] = val
compiled_fn = getattr(compiled_node, name)
+ # Need this so the source_mapping attribute is available for the context
+ # manager to access for runtime errors.
+ #
+ # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
+ # symbol to the compiled module.
+ source_map_attribute_name = 'ag_source_map'
+ if getattr(compiled_fn, source_map_attribute_name, None) is not None:
+ raise ValueError('cannot convert %s because it has an attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled_fn, source_map_attribute_name))
+ setattr(compiled_fn, source_map_attribute_name,
+ compiled_node.__dict__['ag_source_map__'])
+
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
@@ -292,7 +308,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)
+ compiler.ast_to_source(dep, indentation)[0]
for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
return program_ctx.required_imports + '\n\n' + code
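The attribute guard that `to_graph` now applies is a reusable pattern: check for a user-owned attribute before aliasing module state onto the function object. A standalone sketch of just that pattern; `attach_reserved` and `f` are hypothetical names, not part of the API:

    def attach_reserved(fn, name, value):
      # Refuse to clobber an attribute the caller already set.
      if getattr(fn, name, None) is not None:
        raise ValueError(
            'cannot convert %s because it has an attribute "%s", '
            'which is reserved for AutoGraph.' % (fn, name))
      setattr(fn, name, value)
      return fn

    def f(x):
      return x + 1

    attach_reserved(f, 'ag_source_map', {1: 'origin info for line 1'})
    print(f.ag_source_map)  # {1: 'origin info for line 1'}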
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 9943093332..4de7df6572 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -206,8 +206,8 @@ class ApiTest(test.TestCase):
return x
with self.test_session() as sess:
- x = api.converted_call(
- test_fn, False, False, {}, constant_op.constant(-1))
+ x = api.converted_call(test_fn, False, False, {},
+ constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
def test_converted_call_method(self):
@@ -274,8 +274,8 @@ class ApiTest(test.TestCase):
return self.x
with self.test_session() as sess:
- tc = api.converted_call(
- TestClass, False, False, {}, constant_op.constant(-1))
+ tc = api.converted_call(TestClass, False, False, {},
+ constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
self.assertEqual(1, sess.run(x))
@@ -305,6 +305,13 @@ class ApiTest(test.TestCase):
# Just check that it is parseable Python code.
self.assertIsNotNone(parser.parse_str(compiled_code))
+ def test_source_map_attribute_present(self):
+
+ def test_fn(y):
+ return y**2
+
+ self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 776d19f672..bd14359356 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.autograph.converters import call_trees
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import decorators
+from tensorflow.contrib.autograph.converters import error_handlers
from tensorflow.contrib.autograph.converters import ifexp
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.converters import logical_expressions
@@ -40,8 +41,10 @@ from tensorflow.contrib.autograph.converters import single_return
from tensorflow.contrib.autograph.converters import slices
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.core import errors
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import inspect_utils
+from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import transformer
@@ -231,6 +234,8 @@ def _add_self_references(namespace, autograph_module):
ag_internal = imp.new_module('autograph')
ag_internal.converted_call = autograph_module.converted_call
ag_internal.utils = utils
+ ag_internal.rewrite_graph_construction_error = (
+ errors.rewrite_graph_construction_error)
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
@@ -241,9 +246,10 @@ def _add_self_references(namespace, autograph_module):
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
"""Specialization of `entity_to_graph` for callable functions."""
+
node, source = parser.parse_entity(f)
node = node.body[0]
-
+ origin_info.resolve(node, source, f)
namespace = inspect_utils.getnamespace(f)
_add_self_references(namespace, program_ctx.autograph_module)
namer = program_ctx.new_namer(namespace)
@@ -319,4 +325,5 @@ def node_to_graph(node, context):
node = _apply_transformer(node, context, logical_expressions)
node = _apply_transformer(node, context, side_effect_guards)
node = _apply_transformer(node, context, name_scopes)
+ node = _apply_transformer(node, context, error_handlers)
return node
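With `error_handlers` appended to the transformer pipeline, each converted function body is wrapped so graph-construction failures can be remapped to the user's source. An illustrative sketch of the resulting shape (hand-written here, not the converter's literal output; `rewrite_error` stands in for `ag__.rewrite_graph_construction_error`):

    import tensorflow as tf

    AG_SOURCE_MAP = {}  # in the real flow, compiler.ast_to_object fills this

    def rewrite_error(source_map):
      # Stand-in: the real helper rewrites the traceback using the map.
      print('remapping graph construction error via', source_map)

    def tf__f(x):
      try:
        with tf.name_scope('f'):
          return x + tf.constant(1)
      except:  # the generated wrapper re-raises after remapping
        rewrite_error(AG_SOURCE_MAP)
        raise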
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index f5279298af..207225a1ac 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -79,10 +79,12 @@ class ConversionTest(test.TestCase):
self.assertTrue(f in program_ctx.dependency_cache)
self.assertTrue(g in program_ctx.dependency_cache)
self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
- # need the extra .body[0] in order to step past the with tf.name_scope('f')
- # that is added automatically
+ # We need one extra .body[0] to step past the automatically added
+ # try/except wrapper, and another for the automatically added
+ # with tf.name_scope('f') block.
self.assertEqual(
- 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id)
+ 'tf__g',
+ program_ctx.dependency_cache[f].body[0].body[0].body[0].value.func.id)
self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
def test_entity_to_graph_class_hierarchy(self):
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index c900fd6af2..392cb60bcc 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""This module implements operators that we overload.
+"""This module implements operators that AutoGraph overloads.
Note that "operator" is used loosely here, and includes control structures like
conditionals and loops, implemented in functional form, using for example
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD
index a49a4ed05c..f77a6ab392 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/contrib/autograph/pyct/BUILD
@@ -25,6 +25,7 @@ py_library(
"cfg.py",
"compiler.py",
"inspect_utils.py",
+ "origin_info.py",
"parser.py",
"pretty_printer.py",
"qual_names.py",
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py
index ae861627fd..1a52110ef3 100644
--- a/tensorflow/contrib/autograph/pyct/anno.py
+++ b/tensorflow/contrib/autograph/pyct/anno.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Handling annotations on AST nodes.
+"""AST node annotation support.
Adapted from Tangent.
"""
@@ -21,37 +21,90 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from enum import Enum
+import enum
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
-class NoValue(Enum):
+
+# TODO(mdan): Shorten the names.
+# These names are heavily used, and the long anno.<Key> spellings get verbose.
+# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
+
+
+class NoValue(enum.Enum):
def __repr__(self):
return self.name
class Basic(NoValue):
- """Container for annotation keys.
+ """Container for basic annotation keys.
The enum values are used strictly for documentation purposes.
"""
- QN = 'Qualified name, as it appeared in the code.'
+ QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
SKIP_PROCESSING = (
'This node should be preserved as is and not processed any further.')
INDENT_BLOCK_REMAINDER = (
- 'When a node is annotated with this, the remainder of the block should '
- 'be indented below it. The annotation contains a tuple '
- '(new_body, name_map), where `new_body` is the new indented block and '
- '`name_map` allows renaming symbols.')
+ 'When a node is annotated with this, the remainder of the block should'
+ ' be indented below it. The annotation contains a tuple'
+ ' (new_body, name_map), where `new_body` is the new indented block and'
+ ' `name_map` allows renaming symbols.')
+ ORIGIN = ('Information about the source code that converted code originated'
+ ' from. See origin_info.py.')
+
+
+class Static(NoValue):
+ """Container for static analysis annotation keys.
+
+ The enum values are used strictly for documentation purposes.
+ """
+
+ # Deprecated - use reaching definitions instead.
+ # Symbols
+ # These flags are boolean.
+ IS_LOCAL = 'Symbol is local to the function scope being analyzed.'
+ IS_PARAM = 'Symbol is a parameter to the function being analyzed.'
+
+ # Scopes
+ # Scopes are represented by objects of type activity.Scope.
+ SCOPE = 'The scope for the annotated node. See activity.py.'
+ # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
+ ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ COND_SCOPE = 'The scope for the test node of a conditional statement.'
+ BODY_SCOPE = (
+ 'The scope for the main body of a statement (True branch for if '
+ 'statements, main body for loops).')
+ ORELSE_SCOPE = (
+ 'The scope for the orelse body of a statement (False branch for if '
+ 'statements, orelse body for loops).')
+
+ # Static analysis annotations.
+ DEFINITIONS = (
+ 'Reaching definition information. See reaching_definitions.py.')
+ ORIG_DEFINITIONS = (
+ 'The value of DEFINITIONS that applied to the original code before any'
+ ' conversion.')
+ DEFINED_VARS_IN = (
+ 'Symbols defined when entering the node. See reaching_definitions.py.')
+ LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
FAIL = object()
+def keys(node, field_name='___pyct_anno'):
+ if not hasattr(node, field_name):
+ return frozenset()
+ return frozenset(getattr(node, field_name).keys())
+
+
def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
- if (default is FAIL or
- (hasattr(node, field_name) and (key in getattr(node, field_name)))):
+ if (default is FAIL or (hasattr(node, field_name) and
+ (key in getattr(node, field_name)))):
return getattr(node, field_name)[key]
else:
return default
@@ -86,3 +139,19 @@ def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
key,
getanno(from_node, key, field_name=field_name),
field_name=field_name)
+
+
+def dup(node, copy_map, field_name='___pyct_anno'):
+ """Recursively copies annotations in an AST tree.
+
+ Args:
+ node: ast.AST
+ copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
+ key. All annotations with the source key will be copied to identical
+ annotations with the destination key.
+ field_name: str
+ """
+ for n in gast.walk(node):
+ for k in copy_map:
+ if hasanno(n, k, field_name):
+ setanno(n, copy_map[k], getanno(n, k, field_name), field_name)
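A short usage sketch for the new `keys` and `dup` helpers, using only the `anno` and `parser` modules shown in this patch:

    from tensorflow.contrib.autograph.pyct import anno
    from tensorflow.contrib.autograph.pyct import parser

    node = parser.parse_str('a = b')
    anno.setanno(node.body[0], 'spam', 1)

    # keys() reports which annotation keys a node carries.
    print(anno.keys(node.body[0]))  # frozenset({'spam'})

    # dup() walks the whole tree, duplicating 'spam' annotations as 'eggs'.
    anno.dup(node, {'spam': 'eggs'})
    print(anno.getanno(node.body[0], 'eggs'))  # 1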
diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py
index f2c0c8cf05..5ef4da61a3 100644
--- a/tensorflow/contrib/autograph/pyct/anno_test.py
+++ b/tensorflow/contrib/autograph/pyct/anno_test.py
@@ -32,22 +32,27 @@ class AnnoTest(test.TestCase):
def test_basic(self):
node = ast.Name()
+ self.assertEqual(anno.keys(node), set())
self.assertFalse(anno.hasanno(node, 'foo'))
with self.assertRaises(AttributeError):
anno.getanno(node, 'foo')
anno.setanno(node, 'foo', 3)
+
+ self.assertEqual(anno.keys(node), {'foo'})
self.assertTrue(anno.hasanno(node, 'foo'))
self.assertEqual(anno.getanno(node, 'foo'), 3)
self.assertEqual(anno.getanno(node, 'bar', default=7), 7)
anno.delanno(node, 'foo')
+
+ self.assertEqual(anno.keys(node), set())
self.assertFalse(anno.hasanno(node, 'foo'))
with self.assertRaises(AttributeError):
anno.getanno(node, 'foo')
self.assertIsNone(anno.getanno(node, 'foo', default=None))
- def test_copyanno(self):
+ def test_copy(self):
node_1 = ast.Name()
anno.setanno(node_1, 'foo', 3)
@@ -58,6 +63,22 @@ class AnnoTest(test.TestCase):
self.assertTrue(anno.hasanno(node_2, 'foo'))
self.assertFalse(anno.hasanno(node_2, 'bar'))
+ def test_duplicate(self):
+ node = ast.If(
+ test=ast.Num(1),
+ body=[ast.Expr(ast.Name('bar', ast.Load()))],
+ orelse=[])
+ anno.setanno(node, 'spam', 1)
+ anno.setanno(node, 'ham', 1)
+ anno.setanno(node.body[0], 'ham', 1)
+
+ anno.dup(node, {'spam': 'eggs'})
+
+ self.assertTrue(anno.hasanno(node, 'spam'))
+ self.assertTrue(anno.hasanno(node, 'ham'))
+ self.assertTrue(anno.hasanno(node, 'eggs'))
+ self.assertFalse(anno.hasanno(node.body[0], 'eggs'))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py
index c4f82d1170..86e3f56a64 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Copy an AST tree, discarding annotations."""
+"""AST manipulation utilities."""
from __future__ import absolute_import
from __future__ import division
@@ -20,53 +20,60 @@ from __future__ import print_function
import ast
+import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import parser
-class CleanCopier(gast.NodeVisitor):
- """Copies AST nodes.
+class CleanCopier(object):
+ """NodeTransformer-like visitor that copies an AST."""
- The copied nodes will ignore almost all fields that are prefixed by '__'.
- Exceptions make some annotations.
- """
+ def __init__(self, preserve_annos):
+ super(CleanCopier, self).__init__()
+ self.preserve_annos = preserve_annos
- # TODO(mdan): Parametrize which annotations get carried over.
+ def copy(self, node):
+ """Returns a deep copy of node (excluding some fields, see copy_clean)."""
+
+ if isinstance(node, list):
+ return [self.copy(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(self.copy(n) for n in node)
+ elif not isinstance(node, (gast.AST, ast.AST)):
+ # Assuming everything that's not an AST, list or tuple is a value type
+ # and may simply be assigned.
+ return node
+
+ assert isinstance(node, (gast.AST, ast.AST))
- def generic_visit(self, node):
new_fields = {}
for f in node._fields:
- if f.startswith('__'):
- continue
- if not hasattr(node, f):
- continue
- v = getattr(node, f)
- if isinstance(v, list):
- v = [self.generic_visit(n) for n in v]
- elif isinstance(v, tuple):
- v = tuple(self.generic_visit(n) for n in v)
- elif isinstance(v, (gast.AST, ast.AST)):
- v = self.generic_visit(v)
- else:
- # Assume everything else is a value type.
- pass
- new_fields[f] = v
+ if not f.startswith('__') and hasattr(node, f):
+ new_fields[f] = self.copy(getattr(node, f))
new_node = type(node)(**new_fields)
- if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
- anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True)
+
+ if self.preserve_annos:
+ for k in self.preserve_annos:
+ anno.copyanno(node, new_node, k)
return new_node
-def copy_clean(node):
- copier = CleanCopier()
- if isinstance(node, list):
- return [copier.visit(n) for n in node]
- elif isinstance(node, tuple):
- return tuple(copier.visit(n) for n in node)
- else:
- return copier.visit(node)
+def copy_clean(node, preserve_annos=None):
+ """Creates a deep copy of an AST.
+
+ The copy will not include fields that are prefixed by '__', with the
+ exception of user-specified annotations.
+
+ Args:
+ node: ast.AST
+ preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
+ copy
+ Returns:
+ ast.AST
+ """
+ return CleanCopier(preserve_annos).copy(node)
class SymbolRenamer(gast.NodeTransformer):
@@ -78,7 +85,11 @@ class SymbolRenamer(gast.NodeTransformer):
def _process(self, node):
qn = anno.getanno(node, anno.Basic.QN)
if qn in self.name_map:
- return gast.Name(str(self.name_map[qn]), node.ctx, None)
+ new_node = gast.Name(str(self.name_map[qn]), node.ctx, None)
+ # All annotations get carried over.
+ for k in anno.keys(node):
+ anno.copyanno(node, new_node, k)
+ return new_node
return self.generic_visit(node)
def visit_Name(self, node):
@@ -92,6 +103,7 @@ class SymbolRenamer(gast.NodeTransformer):
def rename_symbols(node, name_map):
+ """Renames symbols in an AST. Requires qual_names annotations."""
renamer = SymbolRenamer(name_map)
if isinstance(node, list):
return [renamer.visit(n) for n in node]
@@ -101,6 +113,7 @@ def rename_symbols(node, name_map):
def keywords_to_dict(keywords):
+ """Converts a list of ast.keyword objects to a dict."""
keys = []
values = []
for kw in keywords:
@@ -110,10 +123,7 @@ def keywords_to_dict(keywords):
class PatternMatcher(gast.NodeVisitor):
- """Matches a node against a pattern represented by a node.
-
- The pattern may contain wildcards represented by the symbol '_'.
- """
+ """Matches a node against a pattern represented by a node."""
def __init__(self, pattern):
self.pattern = pattern
@@ -175,11 +185,98 @@ class PatternMatcher(gast.NodeVisitor):
if v != p:
return self.no_match()
-
def matches(node, pattern):
+ """Basic pattern matcher for AST.
+
+ The pattern may contain wildcards represented by the symbol '_'. A node
+ matches a pattern if, for every node in the tree, the corresponding pattern
+ node is either of the same type or a wildcard Name node with id='_'.
+
+ Args:
+ node: ast.AST
+ pattern: ast.AST
+ Returns:
+ bool
+ """
if isinstance(pattern, str):
pattern = parser.parse_expression(pattern)
matcher = PatternMatcher(pattern)
matcher.visit(node)
return matcher.matches
+
+# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+def apply_to_single_assignments(targets, values, apply_fn):
+ """Applies a function to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible, as if the assigned
+ values had been passed to apply_fn in SSA form.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ The function recurses into composite targets, so apply_fn is only ever
+ called with individual (single-target) assignments.
+
+ Args:
+ targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
+ used with the targets field of an ast.Assign node
+ values: ast.AST
+ apply_fn: Callable[[ast.AST, ast.AST], None], called with the
+ respective nodes of each single assignment
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ idx = parser.parse_expression(str(i))
+ value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
+ apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ apply_fn(target, values)
+
+
+def iter_fields(node):
+ for field in sorted(node._fields):
+ try:
+ yield getattr(node, field)
+ except AttributeError:
+ pass
+
+
+def iter_child_nodes(node):
+ for field in iter_fields(node):
+ if isinstance(field, gast.AST):
+ yield field
+ elif isinstance(field, list):
+ for item in field:
+ if isinstance(item, gast.AST):
+ yield item
+
+
+def parallel_walk(node_a, node_b):
+ todo_a = collections.deque([node_a])
+ todo_b = collections.deque([node_b])
+ while todo_a and todo_b:
+ node_a = todo_a.popleft()
+ node_b = todo_b.popleft()
+ todo_a.extend(iter_child_nodes(node_a))
+ todo_b.extend(iter_child_nodes(node_b))
+ yield node_a, node_b
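Mirroring the tests that follow, a minimal sketch of `apply_to_single_assignments` decomposing a dynamic unpack; `report` is a hypothetical callback:

    from tensorflow.contrib.autograph.pyct import ast_util
    from tensorflow.contrib.autograph.pyct import compiler
    from tensorflow.contrib.autograph.pyct import parser

    assign = parser.parse_str('a, b = c').body[0]

    def report(target, value):
      # ast_to_source now returns (code, source_map); only the code is needed.
      target_src, _ = compiler.ast_to_source(target)
      value_src, _ = compiler.ast_to_source(value)
      print(target_src.strip(), '<-', value_src.strip())

    ast_util.apply_to_single_assignments(assign.targets, assign.value, report)
    # a <- c[0]
    # b <- c[1]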
diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py
index 3afa04a506..981e398b93 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util_test.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py
@@ -19,7 +19,10 @@ from __future__ import division
from __future__ import print_function
import ast
+import collections
+import textwrap
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import parser
@@ -29,53 +32,66 @@ from tensorflow.python.platform import test
class AstUtilTest(test.TestCase):
- def test_rename_symbols(self):
- node = ast.Tuple([
- ast.Name('a', ast.Load()),
- ast.Name('b', ast.Load()),
- ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
- ast.Attribute(
- ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
- ], None)
+ def setUp(self):
+ super(AstUtilTest, self).setUp()
+ self._invocation_counts = collections.defaultdict(lambda: 0)
+
+ def test_rename_symbols_basic(self):
+ node = parser.parse_str('a + b')
+ node = qual_names.resolve(node)
+
+ node = ast_util.rename_symbols(
+ node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
+
+ self.assertIsInstance(node.body[0].value.left.id, str)
+ source, _ = compiler.ast_to_source(node)
+ self.assertEqual(source.strip(), 'renamed_a + b')
+
+ def test_rename_symbols_attributes(self):
+ node = parser.parse_str('b.c = b.c.d')
node = qual_names.resolve(node)
+
node = ast_util.rename_symbols(
- node, {
- qual_names.QN('a'):
- qual_names.QN('renamed_a'),
- qual_names.QN(qual_names.QN('b'), attr='c'):
- qual_names.QN('renamed_b_c'),
- })
-
- self.assertEqual(node.elts[0].id, 'renamed_a')
- self.assertTrue(isinstance(node.elts[0].ctx, ast.Load))
- self.assertEqual(node.elts[1].id, 'b')
- self.assertEqual(node.elts[2].id, 'renamed_b_c')
- self.assertTrue(isinstance(node.elts[2].ctx, ast.Store))
- self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
- self.assertTrue(isinstance(node.elts[3].value.ctx, ast.Load))
+ node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
+
+ source, _ = compiler.ast_to_source(node)
+ self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')
+
+ def test_rename_symbols_annotations(self):
+ node = parser.parse_str('a[i]')
+ node = qual_names.resolve(node)
+ anno.setanno(node, 'foo', 'bar')
+ orig_anno = anno.getanno(node, 'foo')
+
+ node = ast_util.rename_symbols(node,
+ {qual_names.QN('a'): qual_names.QN('b')})
+
+ self.assertIs(anno.getanno(node, 'foo'), orig_anno)
def test_copy_clean(self):
- ret = ast.Return(
- ast.BinOp(
- op=ast.Add(),
- left=ast.Name(id='a', ctx=ast.Load()),
- right=ast.Num(1)))
- setattr(ret, '__foo', 'bar')
- node = ast.FunctionDef(
- name='f',
- args=ast.arguments(
- args=[ast.Name(id='a', ctx=ast.Param())],
- vararg=None,
- kwarg=None,
- defaults=[]),
- body=[ret],
- decorator_list=[],
- returns=None)
+ node = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + 1
+ """))
+ setattr(node.body[0], '__foo', 'bar')
new_node = ast_util.copy_clean(node)
- self.assertFalse(node is new_node)
- self.assertFalse(ret is new_node.body[0])
+ self.assertIsNot(new_node, node)
+ self.assertIsNot(new_node.body[0], node.body[0])
self.assertFalse(hasattr(new_node.body[0], '__foo'))
+ def test_copy_clean_preserves_annotations(self):
+ node = parser.parse_str(
+ textwrap.dedent("""
+ def f(a):
+ return a + 1
+ """))
+ anno.setanno(node.body[0], 'foo', 'bar')
+ anno.setanno(node.body[0], 'baz', 1)
+ new_node = ast_util.copy_clean(node, preserve_annos={'foo'})
+ self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar')
+ self.assertFalse(anno.hasanno(new_node.body[0], 'baz'))
+
def test_keywords_to_dict(self):
keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
d = ast_util.keywords_to_dict(keywords)
@@ -113,6 +129,52 @@ class AstUtilTest(test.TestCase):
self.assertNoMatch('super(Foo, self).__init__()',
'super(Bar, _).__init__(_)')
+ def _mock_apply_fn(self, target, source):
+ target, _ = compiler.ast_to_source(target)
+ source, _ = compiler.ast_to_source(source)
+ self._invocation_counts[(target.strip(), source.strip())] += 1
+
+ def test_apply_to_single_assignments_dynamic_unpack(self):
+ node = parser.parse_str('a, b, c = d')
+ node = node.body[0]
+ ast_util.apply_to_single_assignments(node.targets, node.value,
+ self._mock_apply_fn)
+ self.assertDictEqual(self._invocation_counts, {
+ ('a', 'd[0]'): 1,
+ ('b', 'd[1]'): 1,
+ ('c', 'd[2]'): 1,
+ })
+
+ def test_apply_to_single_assignments_static_unpack(self):
+ node = parser.parse_str('a, b, c = d, e, f')
+ node = node.body[0]
+ ast_util.apply_to_single_assignments(node.targets, node.value,
+ self._mock_apply_fn)
+ self.assertDictEqual(self._invocation_counts, {
+ ('a', 'd'): 1,
+ ('b', 'e'): 1,
+ ('c', 'f'): 1,
+ })
+
+ def test_parallel_walk(self):
+ ret = ast.Return(
+ ast.BinOp(
+ op=ast.Add(),
+ left=ast.Name(id='a', ctx=ast.Load()),
+ right=ast.Num(1)))
+ node = ast.FunctionDef(
+ name='f',
+ args=ast.arguments(
+ args=[ast.Name(id='a', ctx=ast.Param())],
+ vararg=None,
+ kwarg=None,
+ defaults=[]),
+ body=[ret],
+ decorator_list=[],
+ returns=None)
+ for child_a, child_b in ast_util.parallel_walk(node, node):
+ self.assertEqual(child_a, child_b)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py
index 666328781f..9f060236f4 100644
--- a/tensorflow/contrib/autograph/pyct/cfg.py
+++ b/tensorflow/contrib/autograph/pyct/cfg.py
@@ -64,11 +64,19 @@ class Node(object):
self.prev = frozenset(self.prev)
def __repr__(self):
- return compiler.ast_to_source(self.ast_node).strip()
+ if isinstance(self.ast_node, gast.FunctionDef):
+ return 'def %s' % self.ast_node.name
+ elif isinstance(self.ast_node, gast.withitem):
+ source, _ = compiler.ast_to_source(self.ast_node.context_expr)
+ return source.strip()
+ source, _ = compiler.ast_to_source(self.ast_node)
+ return source.strip()
class Graph(
- collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])):
+ collections.namedtuple(
+ 'Graph',
+ ['entry', 'exit', 'error', 'index', 'stmt_prev', 'stmt_next'])):
"""A Control Flow Graph.
The CFG maintains an index to allow looking up a CFG node by the AST node to
@@ -82,6 +90,11 @@ class Graph(
because these are shared, and wiring them would create a reverse path from
normal control flow into the error nodes, which we want to avoid.
+ The graph also maintains edges corresponding to higher-level statements
+ like for-else loops. A node is considered a successor of a statement if
+ there is an edge from a node that is lexically a child of that statement to
+ a node that is not. Statement predecessors are defined analogously.
+
Attributes:
entry: Node, the entry node
exit: FrozenSet[Node, ...], the exit nodes
@@ -89,6 +102,10 @@ class Graph(
error (errors propagated from function calls are not accounted)
index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG
node
+ stmt_prev: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST
+ nodes to their predecessor CFG nodes
+ stmt_next: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST
+ nodes to their successor CFG nodes
"""
def __repr__(self):
@@ -96,9 +113,8 @@ class Graph(
for node in self.index.values():
result += ' %s [label="%s"];\n' % (id(node), node)
for node in self.index.values():
- if node.next:
- result += ' %s -> {%s};\n' % (id(node), ', '.join(
- repr(id(n)) for n in node.next))
+ for next_ in node.next:
+ result += ' %s -> %s;\n' % (id(node), id(next_))
result += '}'
return result
@@ -130,25 +146,20 @@ class GraphVisitor(object):
out: Dict[Node, Any], stores node-keyed state during a visit
"""
- def reset(self):
- self.in_ = {
- node: self.init_state(node) for node in self.graph.index.values()
- }
- self.out = {
- node: self.init_state(node) for node in self.graph.index.values()
- }
+ def __init__(self, graph):
+ self.graph = graph
+ self.reset()
def init_state(self, node):
"""State initialization function. Optional to overload.
An in/out state slot will be created for each node in the graph. Subclasses
- may overload this to control what that is initialized to.
+ must overload this to control what each slot is initialized to.
Args:
node: Node
"""
- del node
- return None
+ raise NotImplementedError('Subclasses must implement this.')
def visit_node(self, node):
"""Visitor function.
@@ -161,6 +172,14 @@ class GraphVisitor(object):
"""
raise NotImplementedError('Subclasses must implement this.')
+ def reset(self):
+ self.in_ = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+ self.out = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+
def _visit_internal(self, mode):
"""Visits the CFG, depth-first."""
assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
@@ -169,7 +188,6 @@ class GraphVisitor(object):
elif mode == _WalkMode.REVERSE:
open_ = list(self.graph.exit)
closed = set()
- self.reset()
while open_:
node = open_.pop(0)
@@ -186,12 +204,10 @@ class GraphVisitor(object):
if should_revisit or next_ not in closed:
open_.append(next_)
- def visit_forward(self, graph):
- self.graph = graph
+ def visit_forward(self):
self._visit_internal(_WalkMode.FORWARD)
- def visit_reverse(self, graph):
- self.graph = graph
+ def visit_reverse(self):
self._visit_internal(_WalkMode.REVERSE)
@@ -244,8 +260,16 @@ class GraphBuilder(object):
# TODO(mdan): Too many primitives. Use classes.
self.leaves = set()
+ # Note: This mechanism requires that nodes are added in lexical order (top
+ # to bottom, depth first).
+ self.active_stmts = set()
+ self.owners = {}  # type: Dict[Node, FrozenSet[Hashable]]
+ self.forward_edges = set()  # type: Set[Tuple[Node, Node]]  # (from, to)
+
self.finally_sections = {}
- self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes]
+ # Dict values represent (entry, exits)
+ self.finally_section_subgraphs = {
+ } # type: Dict[ast.AST, Tuple[Node, Set[Node]]]
# Whether the guard section can be reached from the statement that precedes
# it.
self.finally_section_has_direct_flow = {}
@@ -275,6 +299,7 @@ class GraphBuilder(object):
if isinstance(first, Node):
first.next.add(second)
second.prev.add(first)
+ self.forward_edges.add((first, second))
else:
for node in first:
self._connect_nodes(node, second)
@@ -285,6 +310,7 @@ class GraphBuilder(object):
raise ValueError('%s added twice' % ast_node)
node = Node(next_=set(), prev=set(), ast_node=ast_node)
self.node_index[ast_node] = node
+ self.owners[node] = frozenset(self.active_stmts)
if self.head is None:
self.head = node
@@ -299,6 +325,25 @@ class GraphBuilder(object):
return node
+ def begin_statement(self, stmt):
+ """Marks the beginning of a statement.
+
+ Args:
+ stmt: Hashable, a key by which the statement can be identified in
+ the CFG's stmt_prev and stmt_next attributes
+ """
+ self.active_stmts.add(stmt)
+
+ def end_statement(self, stmt):
+ """Marks the end of a statement.
+
+ Args:
+ stmt: Hashable, a key by which the statement can be identified in
+ the CFG's stmt_prev and stmt_next attributes; must match a key
+ previously passed to begin_statement.
+ """
+ self.active_stmts.remove(stmt)
+
def add_ordinary_node(self, ast_node):
"""Grows the graph by adding an ordinary CFG node.
@@ -505,11 +550,35 @@ class GraphBuilder(object):
for node in self.node_index.values():
node.freeze()
+ # Build the statement edges.
+ stmt_next = {}
+ stmt_prev = {}
+ for node, _ in self.forward_edges:
+ for stmt in self.owners[node]:
+ if stmt not in stmt_next:
+ stmt_next[stmt] = set()
+ if stmt not in stmt_prev:
+ stmt_prev[stmt] = set()
+ for first, second in self.forward_edges:
+ stmts_exited = self.owners[first] - self.owners[second]
+ for stmt in stmts_exited:
+ stmt_next[stmt].add(second)
+ stmts_entered = self.owners[second] - self.owners[first]
+ for stmt in stmts_entered:
+ stmt_prev[stmt].add(first)
+ for stmt in stmt_next:
+ stmt_next[stmt] = frozenset(stmt_next[stmt])
+ for stmt in stmt_prev:
+ stmt_prev[stmt] = frozenset(stmt_prev[stmt])
+
+ # Construct the final graph object.
result = Graph(
entry=self.head,
exit=self.leaves,
error=self.errors,
- index=self.node_index)
+ index=self.node_index,
+ stmt_prev=stmt_prev,
+ stmt_next=stmt_next)
# Reset the state.
self.reset()
@@ -523,8 +592,6 @@ class AstToCfg(gast.NodeVisitor):
A separate CFG will be constructed for each function.
"""
- # TODO(mdan): Figure out how to deal with closures.
-
def __init__(self):
super(AstToCfg, self).__init__()
@@ -577,6 +644,13 @@ class AstToCfg(gast.NodeVisitor):
self.builder.add_continue_node(node, try_node, guards)
def visit_FunctionDef(self, node):
+ # We also keep the FunctionDef node in the CFG. This allows us to determine
+ # things like reaching definitions via closure. Note that the function body
+ # will be stored in a separate graph, because function definitions are not
+ # the same as function calls.
+ if self.builder is not None:
+ self.builder.add_ordinary_node(node)
+
self.builder_stack.append(self.builder)
self.builder = GraphBuilder(node)
@@ -637,6 +711,7 @@ class AstToCfg(gast.NodeVisitor):
# targets of jump statements like break/continue/etc. Since there is no
# statement that can interrupt a conditional, we don't need to track their
# lexical scope. That may change in the future.
+ self.builder.begin_statement(node)
self.builder.enter_cond_section(node)
self._process_basic_statement(node.test)
@@ -650,8 +725,10 @@ class AstToCfg(gast.NodeVisitor):
self.visit(stmt)
self.builder.exit_cond_section(node)
+ self.builder.end_statement(node)
def visit_While(self, node):
+ self.builder.begin_statement(node)
self._enter_lexical_scope(node)
self.builder.enter_section(node)
@@ -670,8 +747,10 @@ class AstToCfg(gast.NodeVisitor):
self.visit(stmt)
self.builder.exit_section(node)
+ self.builder.end_statement(node)
def visit_For(self, node):
+ self.builder.begin_statement(node)
self._enter_lexical_scope(node)
self.builder.enter_section(node)
@@ -693,6 +772,7 @@ class AstToCfg(gast.NodeVisitor):
self.visit(stmt)
self.builder.exit_section(node)
+ self.builder.end_statement(node)
def visit_Break(self, node):
self._process_exit_statement(node, gast.While, gast.For)
@@ -722,12 +802,13 @@ class AstToCfg(gast.NodeVisitor):
def visit_With(self, node):
# TODO(mdan): Mark the context manager's exit call as exit guard.
- self._process_basic_statement(node.items)
+ for item in node.items:
+ self._process_basic_statement(item)
for stmt in node.body:
self.visit(stmt)
def build(node):
- builder = AstToCfg()
- builder.visit(node)
- return builder.cfgs
+ visitor = AstToCfg()
+ visitor.visit(node)
+ return visitor.cfgs
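A sketch of querying the new statement edges; `cfg.build` returns one graph per function, and `stmt_next`/`stmt_prev` are keyed by the statement's AST node:

    from tensorflow.contrib.autograph.pyct import cfg
    from tensorflow.contrib.autograph.pyct import parser

    node = parser.parse_str(
        'def f(a):\n'
        '  while a > 0:\n'
        '    a -= 1\n'
        '  return a\n')
    graph, = cfg.build(node).values()

    while_node = node.body[0].body[0]
    # The While statement's successors: CFG nodes reached when exiting it.
    print([repr(n) for n in graph.stmt_next[while_node]])  # ['return a']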
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py
index 00afadd521..9d0a85d615 100644
--- a/tensorflow/contrib/autograph/pyct/cfg_test.py
+++ b/tensorflow/contrib/autograph/pyct/cfg_test.py
@@ -25,9 +25,13 @@ from tensorflow.python.platform import test
class CountingVisitor(cfg.GraphVisitor):
- def __init__(self):
+ def __init__(self, graph):
+ super(CountingVisitor, self).__init__(graph)
self.counts = {}
+ def init_state(self, _):
+ return None
+
def visit_node(self, node):
self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1
return False # visit only once
@@ -51,8 +55,8 @@ class GraphVisitorTest(test.TestCase):
graphs, node = self._build_cfg(test_fn)
graph, = graphs.values()
- visitor = CountingVisitor()
- visitor.visit_forward(graph)
+ visitor = CountingVisitor(graph)
+ visitor.visit_forward()
fn_node = node.body[0]
self.assertEqual(visitor.counts[fn_node.args], 1)
@@ -74,8 +78,8 @@ class GraphVisitorTest(test.TestCase):
graphs, node = self._build_cfg(test_fn)
graph, = graphs.values()
- visitor = CountingVisitor()
- visitor.visit_reverse(graph)
+ visitor = CountingVisitor(graph)
+ visitor.visit_reverse()
fn_node = node.body[0]
self.assertEqual(visitor.counts[fn_node.args], 1)
@@ -94,7 +98,7 @@ class AstToCfgTest(test.TestCase):
return cfgs
def _repr_set(self, node_set):
- return set(repr(n) for n in node_set)
+ return frozenset(repr(n) for n in node_set)
def _as_set(self, elements):
if elements is None:
@@ -110,14 +114,35 @@ class AstToCfgTest(test.TestCase):
matched = False
for cfg_node in graph.index.values():
if repr(cfg_node) == node_repr:
- if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and
- self._as_set(next_) == set(map(repr, cfg_node.next))):
+ if (self._as_set(prev) == frozenset(map(repr, cfg_node.prev)) and
+ self._as_set(next_) == frozenset(map(repr, cfg_node.next))):
matched = True
break
if not matched:
self.fail(
'match failed for node "%s" in graph:\n%s' % (node_repr, graph))
+ def assertStatementEdges(self, graph, edges):
+ """Tests whether the CFG contains the specified statement edges."""
+ for prev_node_reprs, node_repr, next_node_reprs in edges:
+ matched = False
+ partial_matches = []
+ self.assertSetEqual(
+ frozenset(graph.stmt_next.keys()), frozenset(graph.stmt_prev.keys()))
+ for stmt_ast_node in graph.stmt_next:
+ ast_repr = '%s:%s' % (stmt_ast_node.__class__.__name__,
+ stmt_ast_node.lineno)
+ if ast_repr == node_repr:
+ actual_next = frozenset(map(repr, graph.stmt_next[stmt_ast_node]))
+ actual_prev = frozenset(map(repr, graph.stmt_prev[stmt_ast_node]))
+ partial_matches.append((actual_prev, node_repr, actual_next))
+ if (self._as_set(prev_node_reprs) == actual_prev and
+ self._as_set(next_node_reprs) == actual_next):
+ matched = True
+ break
+ if not matched:
+ self.fail('edges mismatch for %s: %s' % (node_repr, partial_matches))
+
def test_straightline(self):
def test_fn(a):
@@ -171,7 +196,7 @@ class AstToCfgTest(test.TestCase):
),
)
- def test_branch_straightline(self):
+ def test_if_straightline(self):
def test_fn(a):
if a > 0:
@@ -189,6 +214,10 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'a += -1', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
def test_branch_nested(self):
@@ -219,6 +248,14 @@ class AstToCfgTest(test.TestCase):
('(a > 2)', 'a = 4', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'If:2', None),
+ ('(a > 0)', 'If:3', None),
+ ('(a > 0)', 'If:8', None),
+ ),
+ )
def test_branch_straightline_semi(self):
@@ -236,6 +273,10 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'a = 1', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
def test_branch_return(self):
@@ -257,6 +298,10 @@ class AstToCfgTest(test.TestCase):
('a = 1', 'a = 2', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', 'a = 2'),),
+ )
def test_branch_return_minimal(self):
@@ -273,6 +318,10 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'return', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
def test_while_straightline(self):
@@ -291,6 +340,10 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'a = 2', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'While:2', 'a = 2'),),
+ )
def test_while_else_straightline(self):
@@ -312,6 +365,10 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'While:2', 'a = 3'),),
+ )
def test_while_else_continue(self):
@@ -339,6 +396,13 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', ('a = 1', '(a > 0)')),
+ ),
+ )
def test_while_else_break(self):
@@ -364,6 +428,13 @@ class AstToCfgTest(test.TestCase):
(('break', 'a = 2'), 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', ('a = 1', 'a = 3')),
+ ),
+ )
def test_while_else_return(self):
@@ -389,6 +460,13 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', 'a = 1'),
+ ),
+ )
def test_while_nested_straightline(self):
@@ -411,6 +489,13 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ),
+ )
def test_while_nested_continue(self):
@@ -437,6 +522,14 @@ class AstToCfgTest(test.TestCase):
('(a > 0)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ('(a > 1)', 'If:4', ('a = 1', '(a > 1)')),
+ ),
+ )
def test_while_nested_break(self):
@@ -451,16 +544,21 @@ class AstToCfgTest(test.TestCase):
graph, = self._build_cfg(test_fn).values()
- self.assertGraphMatches(
+ self.assertGraphMatches(graph, (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
+ ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'a = 1', '(a > 1)'),
+ (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ))
+ self.assertStatementEdges(
graph,
(
- (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
- (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
- ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
- ('(a > 2)', 'break', 'a = 2'),
- ('(a > 2)', 'a = 1', '(a > 1)'),
- (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
- ('(a > 0)', 'a = 3', None),
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ('(a > 1)', 'If:4', ('a = 1', 'a = 2')),
),
)
@@ -481,6 +579,10 @@ class AstToCfgTest(test.TestCase):
('range(0, a)', 'a = 2', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'For:2', 'a = 2'),),
+ )
def test_for_else_straightline(self):
@@ -502,6 +604,10 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (('a', 'For:2', 'a = 3'),),
+ )
def test_for_else_continue(self):
@@ -530,6 +636,13 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', ('a = 1', 'range(0, a)')),
+ ),
+ )
def test_for_else_break(self):
@@ -555,6 +668,13 @@ class AstToCfgTest(test.TestCase):
(('break', 'a = 2'), 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', ('a = 1', 'a = 3')),
+ ),
+ )
def test_for_else_return(self):
@@ -580,6 +700,13 @@ class AstToCfgTest(test.TestCase):
('a = 2', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', 'a = 1'),
+ ),
+ )
def test_for_nested_straightline(self):
@@ -602,6 +729,13 @@ class AstToCfgTest(test.TestCase):
('range(0, a)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ),
+ )
def test_for_nested_continue(self):
@@ -629,6 +763,14 @@ class AstToCfgTest(test.TestCase):
('range(0, a)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ('range(1, a)', 'If:4', ('b += 1', 'range(1, a)')),
+ ),
+ )
def test_for_nested_break(self):
@@ -655,6 +797,14 @@ class AstToCfgTest(test.TestCase):
('range(0, a)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ('range(1, a)', 'If:4', ('b += 1', 'a = 2')),
+ ),
+ )
def test_complex(self):
@@ -704,6 +854,17 @@ class AstToCfgTest(test.TestCase):
('range(1, a)', 'a = 3', None),
),
)
+ self.assertStatementEdges(
+ graph,
+ (
+ ('b = 0', 'While:3', 'range(1, a)'),
+ ('(a > 0)', 'For:4', 'a = 2'),
+ ('range(0, a)', 'If:5', ('(a > 3)', 'a = 2')),
+ ('(a > 2)', 'If:7', ('b += 1', 'a = 2', 'range(0, a)')),
+ ('(a > 3)', 'If:8', ('a = 2', 'range(0, a)')),
+ ('(a > 0)', 'For:17', 'a = 3'),
+ ),
+ )
def test_finally_straightline(self):
@@ -785,6 +946,24 @@ class AstToCfgTest(test.TestCase):
),
)
+ def test_with_straightline(self):
+
+ def test_fn(a):
+ with max(a) as b:
+ a = 0
+ return b
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'max(a)', 'a = 0'),
+ ('max(a)', 'a = 0', 'return b'),
+ ('a = 0', 'return b', None),
+ ),
+ )
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py
index 24c4517afa..c172ab21f6 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/contrib/autograph/pyct/compiler.py
@@ -30,9 +30,49 @@ import tempfile
import astor
import gast
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.contrib.autograph.pyct import parser
+
+
+def _build_source_map(node, code):
+ """Return the Python objects represented by given AST.
+
+ Compiling the AST code this way ensures that the source code is readable by
+ e.g. `pdb` or `inspect`.
+
+ Args:
+ node: An AST node of the original generated code, before the source code is
+ generated.
+ code: The string representation of the source code for the newly generated
+ code.
+
+ Returns:
+ Dict[int, OriginInfo], mapping each generated line number to the
+ OriginInfo of the user code it was generated from.
+ """
+ # After we have the final generated code we reparse it to get the final line
+ # numbers. Then we walk through the generated and original ASTs in parallel
+ # to build the mapping between the user and generated code.
+ new_node = parser.parse_str(code)
+ origin_info.resolve(new_node, code)
+ source_mapping = {}
+ for before, after in ast_util.parallel_walk(node, new_node):
+ # Need both checks because if origin information is ever copied over to new
+ # nodes then we need to rely on the fact that only the original user code
+ # has the origin annotation.
+ if (anno.hasanno(before, anno.Basic.ORIGIN) and
+ anno.hasanno(after, anno.Basic.ORIGIN)):
+ source_info = anno.getanno(before, anno.Basic.ORIGIN)
+ new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
+ source_mapping[new_line_number] = source_info
+ return source_mapping
+
def ast_to_source(node, indentation=' '):
"""Return the source code of given AST."""
+ original_node = node
if isinstance(node, gast.AST):
node = gast.gast_to_ast(node)
generator = astor.codegen.SourceGenerator(indentation, False,
@@ -42,11 +82,16 @@ def ast_to_source(node, indentation=' '):
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
code = map(str, generator.result)
- return astor.source_repr.pretty_source(code).lstrip()
+ code = astor.source_repr.pretty_source(code).lstrip()
+ source_mapping = _build_source_map(original_node, code)
+ return code, source_mapping
-def ast_to_object(
- node, indentation=' ', source_prefix=None, delete_on_exit=True):
+
+def ast_to_object(node,
+ indentation=' ',
+ source_prefix=None,
+ delete_on_exit=True):
"""Return the Python objects represented by given AST.
Compiling the AST code this way ensures that the source code is readable by
@@ -56,15 +101,30 @@ def ast_to_object(
node: The code to compile, as an AST object.
indentation: The string to use for indentation.
source_prefix: Optional string to print as-is into the source file.
- delete_on_exit: Whether to delete the temporary file used for compilation
- on exit.
+ delete_on_exit: Whether to delete the temporary file used for compilation on
+ exit.
Returns:
A module object containing the compiled source code.
+ Raises:
+ ValueError: If ag_source_map__ is already in the namespace of the compiled
+ node.
"""
- source = ast_to_source(node, indentation)
+ # code_source_mapping does not yet include the offsets from import statements.
+ source, code_source_mapping = ast_to_source(node, indentation=indentation)
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
+ # TODO(znado): move into an _offset_source_map() helper function.
+ # Need to offset the generated line numbers by the number of import lines.
+ if source_prefix:
+ num_import_lines = source_prefix.count('\n') + 1
+ else:
+ num_import_lines = 0
+ source_mapping = {}
+ for line_number, original_position in code_source_mapping.items():
+ source_map_key = origin_info.CodeLocation(
+ file_path=f.name, line_number=line_number + num_import_lines)
+ source_mapping[source_map_key] = original_position
module_name = os.path.basename(f.name[:-3])
if source_prefix:
f.write(source_prefix)
@@ -72,4 +132,27 @@ def ast_to_object(
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
- return imp.load_source(module_name, f.name), source
+ compiled_node = imp.load_source(module_name, f.name)
+
+ # TODO(znado): Clean this up so we don't need to attach it to the namespace.
+ # TODO(znado): This does not work for classes because their methods share a
+ # namespace.
+ # This attaches the source map which is needed for error handling. Note that
+ # api.to_graph copies this source map into an attribute of the function.
+ #
+ # We need this so the ag_source_map__ variable is available to the call to
+ # rewrite_graph_construction_error in the except block inside each function
+ # that handles graph construction errors.
+ #
+ # We cannot get the rewritten function name until it is too late, so
+ # templating is hard. Attaching the map at module level also cleanly
+ # handles nested functions, because it lands on the outermost scope.
+ source_map_name = 'ag_source_map__'
+ if source_map_name in compiled_node.__dict__:
+ raise ValueError('cannot convert %s because it has the namespace attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled_node, source_map_name))
+ compiled_node.__dict__[source_map_name] = source_mapping
+
+ return compiled_node, source
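Since `ast_to_source` now returns a `(code, source_mapping)` pair, call sites unpack it. A brief sketch; the map stays empty unless the input AST carries `anno.Basic.ORIGIN` annotations (e.g. from `origin_info.resolve`):

    from tensorflow.contrib.autograph.pyct import compiler
    from tensorflow.contrib.autograph.pyct import parser

    node = parser.parse_str('x = y + 1')
    code, source_map = compiler.ast_to_source(node)
    print(code.strip())  # x = y + 1
    print(source_map)    # {} -- no ORIGIN annotations on the input AST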
diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py
index 98cdc1506b..e29fa9324c 100644
--- a/tensorflow/contrib/autograph/pyct/compiler_test.py
+++ b/tensorflow/contrib/autograph/pyct/compiler_test.py
@@ -59,14 +59,14 @@ class CompilerTest(test.TestCase):
value=gast.Str('c'))
])
+ source, _ = compiler.ast_to_source(node, indentation=' ')
self.assertEqual(
textwrap.dedent("""
if 1:
a = b
else:
a = 'c'
- """).strip(),
- compiler.ast_to_source(node, indentation=' ').strip())
+ """).strip(), source.strip())
def test_ast_to_object(self):
node = gast.FunctionDef(
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
new file mode 100644
index 0000000000..614e346634
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -0,0 +1,100 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Container for origin source code information before AutoGraph compilation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.util import tf_inspect
+
+
+class CodeLocation(
+ collections.namedtuple('CodeLocation', ('file_path', 'line_number'))):
+ """Location of a line of code.
+
+ Attributes:
+ file_path: text, the full path to the file containing the code.
+ line_number: Int, the 1-based line number of the code in its file.
+ """
+ pass
+
+
+class OriginInfo(
+ collections.namedtuple('OriginInfo',
+ ('file_path', 'function_name', 'line_number',
+ 'column_offset', 'source_code_line'))):
+ """Container for information about the source code before conversion.
+
+ Instances of this class contain information about the source code that
+ transformed code originated from. Examples include:
+ * line number
+ * file name
+ * original user code
+ """
+
+ def as_frame(self):
+ """Makes a traceback frame tuple.
+
+ Returns:
+ A tuple of (file_path, line_number, function_name, source_code_line).
+ """
+ return (self.file_path, self.line_number, self.function_name,
+ self.source_code_line)
+
+
+# TODO(znado): Consider refactoring this into a Visitor.
+def resolve(node, source, function=None):
+ """Adds an origin information to all nodes inside the body of function.
+
+ Args:
+ node: The AST node for the function whose body nodes will be annotated.
+ source: Text, the source code string for the function whose body nodes will
+ be annotated.
+ function: Callable, the function whose body nodes will be annotated with
+ an OriginInfo annotation under the key anno.Basic.ORIGIN. If None, only
+ the line numbers and column offsets are set in the annotation; the
+ remaining fields are None.
+ """
+ if function:
+ _, function_lineno = tf_inspect.getsourcelines(function)
+ function_filepath = tf_inspect.getsourcefile(function)
+ else:
+ function_lineno = None
+ function_filepath = None
+ source_lines = source.split('\n')
+ for n in gast.walk(node):
+ if hasattr(n, 'lineno'):
+ # n.lineno is relative to the start of the enclosing function, so we
+ # need to offset it by the function's starting line number.
+ source_code_line = source_lines[n.lineno - 1]
+ if function:
+ source_lineno = n.lineno + function_lineno - 1
+ function_name = function.__name__
+ else:
+ source_lineno = n.lineno
+ function_name = None
+ anno.setanno(
+ n, anno.Basic.ORIGIN,
+ OriginInfo(function_filepath, function_name, source_lineno,
+ n.col_offset, source_code_line))
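A minimal usage sketch for the resolve function above (illustrative only, not part of the patch; it uses the parser and anno modules from the same package):

    # Sketch: annotate a function's AST with origin information.
    from tensorflow.contrib.autograph.pyct import anno
    from tensorflow.contrib.autograph.pyct import origin_info
    from tensorflow.contrib.autograph.pyct import parser

    def f(x):
      return x + 1

    node, source = parser.parse_entity(f)  # module node wrapping the FunctionDef
    origin_info.resolve(node, source, function=f)
    fn_node = node.body[0]
    origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
    print(origin.line_number, origin.source_code_line)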
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py
index da07013cf4..fb81404edc 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names.py
@@ -30,6 +30,7 @@ import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
class Symbol(collections.namedtuple('Symbol', ['name'])):
@@ -89,7 +90,8 @@ class QN(object):
if not isinstance(base, (str, StringLiteral, NumberLiteral)):
# TODO(mdan): Require Symbol instead of string.
raise ValueError(
- 'For simple QNs, base must be a string or a Literal object.')
+ 'for simple QNs, base must be a string or a Literal object;'
+ ' got instead "%s"' % type(base))
assert '.' not in base and '[' not in base and ']' not in base
self._parent = None
self.qn = (base,)
@@ -113,6 +115,22 @@ class QN(object):
return self._parent
@property
+ def owner_set(self):
+ """Returns all the symbols (simple or composite) that own this QN.
+
+ In other words, if this symbol was modified, the symbols in the owner set
+ may also be affected.
+
+ Examples:
+ 'a.b[c.d]' has two owners, 'a' and 'a.b'
+ """
+ owners = set()
+ if self.has_attr() or self.has_subscript():
+ owners.add(self.parent)
+ owners.update(self.parent.owner_set)
+ return owners
+
+ @property
def support_set(self):
"""Returns the set of simple symbols that this QN relies on.
@@ -122,7 +140,7 @@ class QN(object):
Examples:
'a.b' has only one support symbol, 'a'
- 'a[i]' has two roots, 'a' and 'i'
+ 'a[i]' has two support symbols, 'a' and 'i'
"""
# TODO(mdan): This might be the set of Name nodes in the AST. Track those?
roots = set()
@@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer):
def resolve(node):
return QnResolver().visit(node)
+
+
+def from_str(qn_str):
+ node = parser.parse_expression(qn_str)
+ node = resolve(node)
+ return anno.getanno(node, anno.Basic.QN)
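A quick sketch of the new from_str helper together with owner_set and support_set (illustrative only; the expected outputs follow the docstrings and tests in this patch):

    from tensorflow.contrib.autograph.pyct import qual_names

    qn = qual_names.from_str('a.b')
    print({str(o) for o in qn.owner_set})    # {'a'}
    qn = qual_names.from_str('a[b]')
    print({str(s) for s in qn.support_set})  # {'a', 'b'}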
diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py
index 264afd508c..c793c2bb39 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names_test.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py
@@ -30,6 +30,15 @@ from tensorflow.python.platform import test
class QNTest(test.TestCase):
+ def test_from_str(self):
+ a = QN('a')
+ b = QN('b')
+ a_dot_b = QN(a, attr='b')
+ a_sub_b = QN(a, subscript=b)
+ self.assertEqual(qual_names.from_str('a.b'), a_dot_b)
+ self.assertEqual(qual_names.from_str('a'), a)
+ self.assertEqual(qual_names.from_str('a[b]'), a_sub_b)
+
def test_basic(self):
a = QN('a')
self.assertEqual(a.qn, ('a',))
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
index bcf2dacec2..25f78536e0 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
@@ -19,8 +19,10 @@ py_library(
srcs = [
"activity.py",
"annos.py",
- "cfg.py",
+ "cfg.py", # TODO(mdan): Remove.
"live_values.py",
+ "liveness.py",
+ "reaching_definitions.py",
"type_info.py",
],
srcs_version = "PY2AND3",
@@ -28,6 +30,7 @@ py_library(
deps = [
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python:util",
"@gast_archive//:gast",
],
)
@@ -70,6 +73,37 @@ py_test(
],
)
+# TODO(mdan): Enable these tests once child change is in.
+py_test(
+ name = "liveness_test",
+ srcs = ["liveness_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":static_analysis",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "reaching_definitions_test",
+ srcs = ["reaching_definitions_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":static_analysis",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "type_info_test",
srcs = ["type_info_test.py"],
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
index c325e19f28..9a82de735d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
@@ -18,10 +18,14 @@ This module contains utilities to help annotate AST nodes with as much runtime
information as can possibly be extracted without actually executing the code,
under the assumption that the context in which the code will run is known.
-Note: It's a fair bet that this analysis cannot be reused across contexts
-without re-running it. In most cases, the context usually means referenced
-modules, which should be static enough to allow reuse, but that is not being
-reliably verified.
+Overall, the different analyses perform the functions listed below:
+
+ * activity: inventories symbols read, written to, params, etc. at different
+ levels
+ * liveness, reaching_definitions: dataflow analyses based on the program's CFG
+ and using the symbol information gathered by activity analysis
+ * live_values, type_info: type and value inference based on dataflow
+ analysis and context information
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
index b929b35b79..5eefecf278 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
@@ -21,6 +21,9 @@ from __future__ import print_function
from enum import Enum
+# TODO(mdan): Remove.
+
+
class NoValue(Enum):
def __repr__(self):
@@ -50,10 +53,3 @@ class NodeAnno(NoValue):
ORELSE_SCOPE = (
'The scope for the orelse body of a statement (False branch for if '
'statements, orelse body for loops).')
-
- # Type and Value annotations
- # Type annotations are represented by objects of type type_info.Type.
- STATIC_INFO = (
- 'The type or value information that should be asserted about the entity '
- 'referenced by the symbol holding this annotation, irrespective of the '
- 'execution context.')
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
new file mode 100644
index 0000000000..bf29d868a2
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
@@ -0,0 +1,200 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Live variable analysis.
+
+This analysis attaches to each control flow statement a set containing the
+symbols that are live at its exit.
+
+Requires activity analysis.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import annos
+
+
+class Analyzer(cfg.GraphVisitor):
+ """CFG visitor that performs liveness analysis at statement level."""
+
+ def __init__(self, graph):
+ super(Analyzer, self).__init__(graph)
+ # This allows communicating that nodes generate extra symbols,
+ # e.g. those that a function definition closes over.
+ self.extra_gen = {}
+
+ def init_state(self, _):
+ return set()
+
+ def visit_node(self, node):
+ prev_live_in = self.in_[node]
+
+ if anno.hasanno(node.ast_node, anno.Static.SCOPE):
+ node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
+
+ gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
+ # TODO(mdan): verify whether composites' parents need to be added.
+ # E.g. if x.y is live whether x needs to be added. Theoretically the
+ # activity analysis should have both so that wouldn't be needed.
+ kill = node_scope.modified
+
+ live_out = set()
+ for n in node.next:
+ live_out |= self.in_[n]
+ live_in = gen | (live_out - kill)
+
+ else:
+ # Nodes that don't have a scope annotation are assumed not to touch any
+ # symbols.
+ # A Name node that reaches this point is a literal name, e.g. False.
+ assert isinstance(node.ast_node,
+ (gast.Name, gast.Continue, gast.Break)), type(
+ node.ast_node)
+ live_in = prev_live_in
+ live_out = live_in
+
+ self.in_[node] = live_in
+ self.out[node] = live_out
+
+ # TODO(mdan): Move this to the superclass?
+ return prev_live_in != live_in
+
+
+class WholeTreeAnalyzer(transformer.Base):
+ """Runs liveness analysis on each of the functions defined in the AST.
+
+ If a function defines other local functions, those will have separate CFGs.
+ However, dataflow analysis needs to tie up these CFGs to properly emulate the
+ effect of closures. In the case of liveness, the parent function's live
+ variables must account for the variables that are live at the entry of each
+ subfunction. For example:
+
+ def foo():
+ # baz is live here
+ def bar():
+ print(baz)
+
+ This analyzer runs liveness analysis on each individual function, accounting
+ for the effect above.
+ """
+
+ def __init__(self, source_info, graphs):
+ super(WholeTreeAnalyzer, self).__init__(source_info)
+ self.graphs = graphs
+ self.current_analyzer = None
+ self.analyzers = {}
+
+ def visit_FunctionDef(self, node):
+ parent_analyzer = self.current_analyzer
+ subgraph = self.graphs[node]
+
+ # Postorder tree processing makes this a bit complicated:
+ # 1. construct an analyzer object and put it on the stack
+ # 2. recursively walk the subtree; this will initialize the analyzer's
+ # in_ state properly (done in a block below)
+ # 3. run the final analysis
+ analyzer = Analyzer(subgraph)
+ self.current_analyzer = analyzer
+ node = self.generic_visit(node)
+ analyzer.visit_reverse()
+
+ if parent_analyzer is not None:
+ # Wire the state between the two subgraphs' analyzers.
+ child_in_state = analyzer.in_[subgraph.entry]
+ # Exception: symbols modified in the child function are local to it
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ for qn in body_scope.modified:
+ # Note: a function modifying the symbol doesn't make that symbol
+ # live at the function's entry. In fact when that happens it is
+ # probably a case of undefined assignment, like this:
+ #
+ # bar = 0
+ # def foo():
+ # print(bar) # bar is undefined here!
+ # bar = 1
+ #
+ # Hence we use discard and not remove below.
+ child_in_state.discard(qn)
+ parent_analyzer.extra_gen[node] = frozenset(child_in_state)
+
+ self.analyzers[node] = analyzer
+ self.current_analyzer = parent_analyzer
+ return node
+
+ def visit_Nonlocal(self, node):
+ raise NotImplementedError()
+
+ def visit_Global(self, node):
+ raise NotImplementedError()
+
+
+class Annotator(transformer.Base):
+ """AST visitor that annotates each control flow block with live symbols."""
+
+ # Note: additional nodes may be added as needed.
+
+ def __init__(self, source_info, cross_function_analyzer):
+ super(Annotator, self).__init__(source_info)
+ self.cross_function_analyzer = cross_function_analyzer
+ self.current_analyzer = None
+
+ def visit_FunctionDef(self, node):
+ parent_analyzer = self.current_analyzer
+ self.current_analyzer = self.cross_function_analyzer.analyzers[node]
+
+ node = self.generic_visit(node)
+ self.current_analyzer = parent_analyzer
+ return node
+
+ def _aggregate_successors_live_in(self, node):
+ successors = self.current_analyzer.graph.stmt_next[node]
+ node_live_out = set()
+ for s in successors:
+ node_live_out.update(self.current_analyzer.in_[s])
+ anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out))
+ node = self.generic_visit(node)
+ return node
+
+ def visit_If(self, node):
+ return self._aggregate_successors_live_in(node)
+
+ def visit_For(self, node):
+ return self._aggregate_successors_live_in(node)
+
+ def visit_While(self, node):
+ return self._aggregate_successors_live_in(node)
+
+
+def resolve(node, source_info, graphs):
+ """Resolves the live symbols at the exit of control flow statements.
+
+ Args:
+ node: ast.AST
+ source_info: transformer.EntityInfo
+ graphs: Dict[ast.FunctionDef, cfg.Graph]
+ Returns:
+ ast.AST
+ """
+ cross_function_analyzer = WholeTreeAnalyzer(source_info, graphs)
+ node = cross_function_analyzer.visit(node)
+ visitor = Annotator(source_info, cross_function_analyzer)
+ node = visitor.visit(node)
+ return node
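The core of Analyzer.visit_node above is the textbook backward dataflow equation live_in = gen | (live_out - kill). A self-contained sketch of that update on a two-statement chain, with plain sets standing in for the analyzer's state (illustrative only):

    # Statements: s1: x = a + b ; s2: return x
    gen = {'s1': {'a', 'b'}, 's2': {'x'}}   # symbols read by each statement
    kill = {'s1': {'x'}, 's2': set()}       # symbols written by each statement
    live_out = {'s2': set()}                # nothing is live after the return
    live_in = {}
    live_in['s2'] = gen['s2'] | (live_out['s2'] - kill['s2'])  # {'x'}
    live_out['s1'] = set(live_in['s2'])     # successor's live-in
    live_in['s1'] = gen['s1'] | (live_out['s1'] - kill['s1'])  # {'a', 'b'}
    assert live_in['s1'] == {'a', 'b'}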
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
new file mode 100644
index 0000000000..d53adb28af
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
@@ -0,0 +1,149 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for liveness module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import activity
+from tensorflow.contrib.autograph.pyct.static_analysis import liveness
+from tensorflow.python.platform import test
+
+
+class LivenessTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, entity_info)
+ graphs = cfg.build(node)
+ liveness.resolve(node, entity_info, graphs)
+ return node
+
+ def assertHasLiveOut(self, node, expected):
+ live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+ live_out_str = set(str(v) for v in live_out)
+ if not expected:
+ expected = ()
+ if not isinstance(expected, tuple):
+ expected = (expected,)
+ self.assertSetEqual(live_out_str, set(expected))
+
+ def test_stacked_if(self):
+
+ def test_fn(x, a):
+ if a > 0:
+ x = 0
+ if a > 1:
+ x = 1
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], ('a', 'x'))
+ self.assertHasLiveOut(fn_body[1], 'x')
+
+ def test_stacked_if_else(self):
+
+ def test_fn(x, a):
+ if a > 0:
+ x = 0
+ if a > 1:
+ x = 1
+ else:
+ x = 2
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], 'a')
+ self.assertHasLiveOut(fn_body[1], 'x')
+
+ def test_for_basic(self):
+
+ def test_fn(x, a):
+ for i in range(a):
+ x += i
+ return x
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], 'x')
+
+ def test_attributes(self):
+
+ def test_fn(x, a):
+ if a > 0:
+ x.y = 0
+ return x.y
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], ('x.y', 'x'))
+
+ def test_nested_functions(self):
+
+ def test_fn(a, b):
+ if b:
+ a = []
+
+ def foo():
+ return a
+
+ foo()
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], 'a')
+
+ def test_nested_functions_isolation(self):
+
+ def test_fn(b):
+ if b:
+ a = 0 # pylint:disable=unused-variable
+
+ def child():
+ max(a) # pylint:disable=used-before-assignment
+ a = 1
+ return a
+
+ child()
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasLiveOut(fn_body[0], 'max')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
new file mode 100644
index 0000000000..4d79b0a56a
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
@@ -0,0 +1,273 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reaching definition analysis.
+
+This analysis attaches a set of Definition objects to each symbol, one for
+each distinct definition that may reach it. The Definition objects are
+mutable and may be used by subsequent analyses to attach additional data,
+such as static type and value information.
+The analysis also attaches to each control flow statement the set of symbols
+defined at its entry.
+
+Requires activity analysis.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import annos
+
+
+class Definition(object):
+ """Definition objects describe a unique definition of a variable.
+
+ Subclasses of this may be used by passing an appropriate factory function to
+ resolve.
+
+ Attributes:
+ param_of: Optional[ast.AST]
+ """
+
+ def __init__(self):
+ self.param_of = None
+
+ def __repr__(self):
+ return '%s[%d]' % (self.__class__.__name__, id(self))
+
+
+class _NodeState(object):
+ """Abstraction for the state of the CFG walk for reaching definition analysis.
+
+ This is a value type; it implements only the strictly necessary operators.
+
+ Attributes:
+ value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
+ their possible definitions
+ """
+
+ def __init__(self, init_from=None):
+ if init_from:
+ if isinstance(init_from, _NodeState):
+ self.value = {
+ s: set(other_infos) for s, other_infos in init_from.value.items()
+ }
+ elif isinstance(init_from, dict):
+ self.value = {s: set((init_from[s],)) for s in init_from}
+ else:
+ assert False, init_from
+ else:
+ self.value = {}
+
+ def __eq__(self, other):
+ if frozenset(self.value.keys()) != frozenset(other.value.keys()):
+ return False
+ ret = all(self.value[s] == other.value[s] for s in self.value)
+ return ret
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __or__(self, other):
+ assert isinstance(other, _NodeState)
+ result = _NodeState(self)
+ for s, other_infos in other.value.items():
+ if s in result.value:
+ result.value[s].update(other_infos)
+ else:
+ result.value[s] = set(other_infos)
+ return result
+
+ def __sub__(self, other):
+ assert isinstance(other, set)
+ result = _NodeState(self)
+ for s in other:
+ result.value.pop(s, None)
+ return result
+
+ def __repr__(self):
+ return 'NodeState[%s]=%s' % (id(self), repr(self.value))
+
+
+class Analyzer(cfg.GraphVisitor):
+ """CFG visitor that determines reaching definitions at statement level."""
+
+ def __init__(self, graph, definition_factory):
+ self._definition_factory = definition_factory
+ super(Analyzer, self).__init__(graph)
+ self.defs_by_ast_node = {}
+ # This allows communicating that nodes have extra reaching definitions,
+ # e.g. those that a function closes over.
+ self.extra_in = {}
+
+ self.gen_map = {}
+
+ def init_state(self, _):
+ return _NodeState()
+
+ def visit_node(self, node):
+ prev_defs_out = self.out[node]
+
+ defs_in = _NodeState(self.extra_in.get(node.ast_node, None))
+ for n in node.prev:
+ defs_in |= self.out[n]
+
+ if anno.hasanno(node.ast_node, anno.Static.SCOPE):
+ node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
+ # The definition objects created by each node must be singletons because
+ # their ids are used in equality checks.
+ if node not in self.gen_map:
+ node_symbols = {}
+ for s in node_scope.modified:
+ def_ = self._definition_factory()
+ if s in node_scope.params:
+ def_.param_of = node_scope.params[s]
+ node_symbols[s] = def_
+ self.gen_map[node] = _NodeState(node_symbols)
+
+ gen = self.gen_map[node]
+ kill = node_scope.modified
+ defs_out = gen | (defs_in - kill)
+
+ else:
+ # Nodes that don't have a scope annotation are assumed not to touch any
+ # symbols.
+ # A Name node that reaches this point is a literal name, e.g. False.
+ # This can also happen if activity.py forgot to annotate the node with a
+ # scope object.
+ assert isinstance(node.ast_node,
+ (gast.Name, gast.Break, gast.Continue)), (node.ast_node,
+ node)
+ defs_out = defs_in
+
+ self.in_[node] = defs_in
+ self.out[node] = defs_out
+ self.defs_by_ast_node[node.ast_node] = defs_out.value
+
+ # TODO(mdan): Move this to the superclass?
+ return prev_defs_out != defs_out
+
+
+class WholeTreeAnalyzer(transformer.Base):
+ """AST visitor that annotates each symbol name with its reaching definitions.
+
+ Simultaneously, the visitor runs the dataflow analysis on each function node,
+ accounting for the effect of closures. For example:
+
+ def foo():
+ bar = 1
+ def baz():
+ # bar = 1 reaches here
+ """
+
+ def __init__(self, source_info, graphs, definition_factory):
+ super(WholeTreeAnalyzer, self).__init__(source_info)
+ self.stmt_reaching_defs_info = None
+ self.graphs = graphs
+ self.current_analyzer = None
+ self.definition_factory = definition_factory
+ self.current_stmt_defs = None
+
+ def visit_FunctionDef(self, node):
+ parent_analyzer = self.current_analyzer
+ subgraph = self.graphs[node]
+
+ # Preorder tree processing:
+ # 1. if this is a child function, the parent was already analyzed and it
+ # has the proper state value for the subgraph's entry
+ # 2. analyze the current function body
+ # 3. recursively walk the subtree; child functions will be processed
+ analyzer = Analyzer(subgraph, self.definition_factory)
+ if parent_analyzer is not None:
+ # Wire the state between the two subgraphs' analyzers.
+ parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
+ # Exception: symbols modified in the child function are local to it
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ parent_out_state -= body_scope.modified
+ analyzer.extra_in[node.args] = parent_out_state
+
+ # Complete the analysis for the local function and annotate its body.
+ analyzer.visit_forward()
+
+ # Recursively process any remaining subfunctions.
+ self.current_analyzer = analyzer
+ node = self.generic_visit(node)
+ self.current_analyzer = parent_analyzer
+
+ return node
+
+ def visit_Nonlocal(self, node):
+ raise NotImplementedError()
+
+ def visit_Global(self, node):
+ raise NotImplementedError()
+
+ def visit_Name(self, node):
+ if self.current_analyzer is None:
+ # Names may appear outside function defs - for example in class
+ # definitions.
+ return node
+
+ qn = anno.getanno(node, anno.Basic.QN)
+ assert self.current_stmt_defs is not None, (
+ 'name node outside of any statement?')
+ anno.setanno(node, anno.Static.DEFINITIONS,
+ tuple(self.current_stmt_defs.get(qn, ())))
+ return node
+
+ def _aggregate_predecessors_defined_in(self, node):
+ preds = self.current_analyzer.graph.stmt_prev[node]
+ node_defined_in = set()
+ for p in preds:
+ node_defined_in |= set(self.current_analyzer.out[p].value.keys())
+ anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))
+ node = self.generic_visit(node)
+ return node
+
+ def visit_If(self, node):
+ return self._aggregate_predecessors_defined_in(node)
+
+ def visit_For(self, node):
+ return self._aggregate_predecessors_defined_in(node)
+
+ def visit_While(self, node):
+ return self._aggregate_predecessors_defined_in(node)
+
+ def visit(self, node):
+ if (self.current_analyzer is not None and
+ node in self.current_analyzer.defs_by_ast_node):
+ self.current_stmt_defs = self.current_analyzer.defs_by_ast_node[node]
+ return super(WholeTreeAnalyzer, self).visit(node)
+
+
+def resolve(node, source_info, graphs, definition_factory):
+ """Resolves reaching definitions for each symbol.
+
+ Args:
+ node: ast.AST
+ source_info: transformer.EntityInfo
+ graphs: Dict[ast.FunctionDef, cfg.Graph]
+ definition_factory: Callable[[], Definition]
+ Returns:
+ ast.AST
+ """
+ visitor = WholeTreeAnalyzer(source_info, graphs, definition_factory)
+ node = visitor.visit(node)
+ return node
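Analyzer.visit_node applies the forward counterpart of the liveness equation: defs_out = gen | (defs_in - kill), where the state maps each symbol to the set of definitions that may reach it. A sketch with plain dicts standing in for _NodeState (illustrative only; the definition names are hypothetical):

    defs_in = {'a': {'def_a_1'}, 'b': {'def_b_1'}}  # definitions reaching the node
    gen = {'a': {'def_a_2'}}                        # the node redefines a
    kill = {'a'}                                    # prior definitions of a are killed
    defs_out = {s: set(d) for s, d in defs_in.items() if s not in kill}
    for s, new_defs in gen.items():
      defs_out.setdefault(s, set()).update(new_defs)
    assert defs_out == {'a': {'def_a_2'}, 'b': {'def_b_1'}}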
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
new file mode 100644
index 0000000000..0410bb2a35
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -0,0 +1,221 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for reaching_definitions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import activity
+from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.platform import test
+
+
+class DefinitionInfoTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, entity_info)
+ graphs = cfg.build(node)
+ node = reaching_definitions.resolve(node, entity_info, graphs,
+ reaching_definitions.Definition)
+ return node
+
+ def assertHasDefs(self, node, num):
+ defs = anno.getanno(node, anno.Static.DEFINITIONS)
+ self.assertEqual(len(defs), num)
+ for r in defs:
+ self.assertIsInstance(r, reaching_definitions.Definition)
+
+ def assertHasDefinedIn(self, node, expected):
+ defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
+ defined_in_str = set(str(v) for v in defined_in)
+ if not expected:
+ expected = ()
+ if not isinstance(expected, tuple):
+ expected = (expected,)
+ self.assertSetEqual(defined_in_str, set(expected))
+
+ def test_conditional(self):
+
+ def test_fn(a, b):
+ a = []
+ if b:
+ a = []
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[1].test, 1)
+ self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[2].value, 2)
+
+ self.assertHasDefinedIn(fn_body[1], ('a', 'b'))
+
+ def test_while(self):
+
+ def test_fn(a):
+ max(a)
+ while True:
+ a = a
+ a = a
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[0].value.args[0], 1)
+ self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[1].body[0].value, 1)
+ self.assertHasDefs(fn_body[1].body[1].targets[0], 1)
+ self.assertHasDefs(fn_body[1].body[1].value, 1)
+ # The loop does have an invariant test, but the CFG doesn't know that.
+ self.assertHasDefs(fn_body[2].value, 2)
+
+ def test_while_else(self):
+
+ def test_fn(x, i):
+ y = 0
+ while x:
+ x += i
+ if i:
+ break
+ else:
+ y = 1
+ return x, y
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[1].test, 2)
+ self.assertHasDefs(fn_body[1].body[0].target, 1)
+ self.assertHasDefs(fn_body[1].body[1].test, 1)
+ self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1)
+ self.assertHasDefs(fn_body[2].value.elts[0], 2)
+ self.assertHasDefs(fn_body[2].value.elts[1], 2)
+
+ def test_for_else(self):
+
+ def test_fn(x, i):
+ y = 0
+ for i in x:
+ x += i
+ if i:
+ break
+ else:
+ continue
+ else:
+ y = 1
+ return x, y
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[1].target, 1)
+ self.assertHasDefs(fn_body[1].body[0].target, 1)
+ self.assertHasDefs(fn_body[1].body[1].test, 1)
+ self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1)
+ self.assertHasDefs(fn_body[2].value.elts[0], 2)
+ self.assertHasDefs(fn_body[2].value.elts[1], 2)
+
+ def test_nested_functions(self):
+
+ def test_fn(a, b):
+ a = []
+ if b:
+ a = []
+
+ def foo():
+ return a
+
+ foo()
+
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+ def_of_a_in_if = fn_body[1].body[0].targets[0]
+
+ self.assertHasDefs(fn_body[0].targets[0], 1)
+ self.assertHasDefs(fn_body[1].test, 1)
+ self.assertHasDefs(def_of_a_in_if, 1)
+ self.assertHasDefs(fn_body[2].value, 2)
+
+ inner_fn_body = fn_body[1].body[1].body
+ self.assertHasDefs(inner_fn_body[0].value, 1)
+ self.assertTrue(
+ anno.getanno(inner_fn_body[0].value, anno.Static.DEFINITIONS)[0] is
+ anno.getanno(def_of_a_in_if, anno.Static.DEFINITIONS)[0])
+
+ def test_nested_functions_isolation(self):
+
+ def test_fn(a):
+ a = 0
+
+ def child():
+ a = 1
+ return a
+
+ child()
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[3].value, 1)
+ self.assertHasDefs(fn_body[1].body[1].value, 1)
+
+ parent_return = fn_body[3]
+ child_return = fn_body[1].body[1]
+ # The assignment `a = 1` makes `a` local to `child`.
+ self.assertFalse(
+ anno.getanno(parent_return.value, anno.Static.DEFINITIONS)[0] is
+ anno.getanno(child_return.value, anno.Static.DEFINITIONS)[0])
+
+ def test_debug(self):
+
+ def foo(_):
+ pass
+
+ def test_fn(a):
+ with foo(a):
+ return a
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0)
+ self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index 9c479ebc2f..72d1d3b269 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -26,6 +26,7 @@ import textwrap
import gast
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
@@ -43,39 +44,65 @@ class ReplaceTransformer(gast.NodeTransformer):
"""
self.replacements = replacements
self.in_replacements = False
+ self.preserved_annos = {
+ anno.Basic.ORIGIN,
+ anno.Basic.SKIP_PROCESSING,
+ anno.Static.ORIG_DEFINITIONS,
+ }
+
+ def _prepare_replacement(self, replaced, key):
+ """Prepares a replacement AST that's safe to swap in for a node.
+
+ Args:
+ replaced: ast.AST, the node being replaced
+ key: Hashable, the key of the replacement AST
+ Returns:
+ ast.AST, the replacement AST
+ """
+ repl = self.replacements[key]
+
+ new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
+ if isinstance(new_nodes, gast.AST):
+ new_nodes = [new_nodes]
+
+ return new_nodes
def visit_Expr(self, node):
- if (isinstance(node.value, gast.Name) and
- node.value.id in self.replacements):
- return self.visit(node.value)
- self.generic_visit(node)
- return node
+ # When replacing a placeholder with an entire statement, the replacement
+ # must stand on its own and not be wrapped in an Expr.
+ new_value = self.visit(node.value)
+ if new_value is node.value:
+ return node
+ return new_value
def visit_keyword(self, node):
- if node.arg in self.replacements:
- repl = self.replacements[node.arg]
- if isinstance(repl, gast.keyword):
- return repl
- elif (isinstance(repl, (list, tuple)) and repl and
- all(isinstance(r, gast.keyword) for r in repl)):
- return repl
- # TODO(mdan): We may allow replacing with a string as well.
- # For example, if one wanted to replace foo with bar in foo=baz, then
- # we could allow changing just node arg, so that we end up with bar=baz.
- raise ValueError(
- 'a keyword argument may only be replaced by another keyword or a '
- 'non-empty list of keywords. Found: %s' % repl)
- return self.generic_visit(node)
+ if node.arg not in self.replacements:
+ return self.generic_visit(node)
+
+ repl = self._prepare_replacement(node, node.arg)
+ if isinstance(repl, gast.keyword):
+ return repl
+ elif (repl and isinstance(repl, (list, tuple)) and
+ all(isinstance(r, gast.keyword) for r in repl)):
+ return repl
+ # TODO(mdan): We may allow replacing with a string as well.
+ # For example, if one wanted to replace foo with bar in foo=baz, then
+ # we could allow changing just node arg, so that we end up with bar=baz.
+ raise ValueError(
+ 'a keyword argument may only be replaced by another keyword or a '
+ 'non-empty list of keywords. Found: %s' % repl)
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- if node.name in self.replacements:
- repl = self.replacements[node.name]
- if not isinstance(repl, (gast.Name, ast.Name)):
- raise ValueError(
- 'a function name can only be replaced by a Name node. Found: %s' %
- repl)
- node.name = repl.id
+ if node.name not in self.replacements:
+ return node
+
+ repl = self.replacements[node.name]
+ if not isinstance(repl, (gast.Name, ast.Name)):
+ raise ValueError(
+ 'a function name can only be replaced by a Name node. Found: %s' %
+ repl)
+ node.name = repl.id
return node
def _check_has_context(self, node):
@@ -148,6 +175,7 @@ class ReplaceTransformer(gast.NodeTransformer):
node = self.generic_visit(node)
if node.attr not in self.replacements:
return node
+
repl = self.replacements[node.attr]
if not isinstance(repl, gast.Name):
raise ValueError(
@@ -159,9 +187,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if node.id not in self.replacements:
return node
- new_nodes = ast_util.copy_clean(self.replacements[node.id])
- if isinstance(new_nodes, gast.AST):
- new_nodes = [new_nodes]
+ new_nodes = self._prepare_replacement(node, node.id)
# Preserve the target context.
for n in new_nodes:
@@ -182,7 +208,7 @@ class ReplaceTransformer(gast.NodeTransformer):
def _convert_to_ast(n):
- """Convert from a known data type to AST."""
+ """Converts from a known data type to AST."""
if isinstance(n, str):
# Note: the node will receive the ctx value from the template, see
# ReplaceTransformer.visit_Name.
@@ -197,7 +223,7 @@ def _convert_to_ast(n):
def replace(template, **replacements):
- """Replace placeholders in a Python template.
+ """Replaces placeholders in a Python template.
AST Name and Tuple nodes always receive the context inferred from
the template. However, when replacing more complex nodes (that can potentially
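For reference, a sketch of how replace is typically invoked, showing the behavior the rewritten visit_Expr guards: a placeholder statement is swapped for a whole statement rather than being wrapped in an Expr (illustrative only; the template text and placeholder name are hypothetical):

    from tensorflow.contrib.autograph.pyct import parser
    from tensorflow.contrib.autograph.pyct import templates

    template = '''
      def test_fn(a):
        body_placeholder  # replaced by a whole statement, no Expr wrapper
        return a
    '''
    stmt = parser.parse_str('a += 1').body[0]
    nodes = templates.replace(template, body_placeholder=stmt)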
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index 7655811830..bbdfefc50a 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -59,6 +59,103 @@ class EntityInfo(object):
self.owner_type = owner_type
+class _StateStack(object):
+ """Typed stack abstraction.
+
+ This class provides syntactic sugar for a stack of objects of known
+ type. Attributes of the object at the top of the stack can be accessed
+ directly against this object, which makes for very terse syntax.
+
+ For example, this code:
+
+ stack = _StateStack(Foo)
+ stack.enter()
+ stack.bar
+
+ Is equivalent to:
+
+ stack = []
+ stack.append(Foo())
+ foo = stack[-1]
+ foo.bar
+
+ See _State for more on how this is used.
+
+ Attributes:
+ type: Any, the type of objects that this stack holds
+ level: int, the current stack depth
+ value: Any, the instance of the object at the top of the stack
+ """
+
+ def __init__(self, type_):
+ # Because we override __setattr__, we need to attach these attributes using
+ # the superclass' setattr.
+ object.__setattr__(self, 'type', type_)
+ object.__setattr__(self, '_stack', [])
+ self.enter()
+
+ def enter(self):
+ self._stack.append(self.type())
+
+ def exit(self):
+ return self._stack.pop()
+
+ @property
+ def level(self):
+ return len(self._stack)
+
+ @property
+ def value(self):
+ return self._stack[-1]
+
+ def __getattr__(self, key):
+ return getattr(self._stack[-1], key)
+
+ def __setattr__(self, key, value):
+ setattr(self._stack[-1], key, value)
+
+
+class _State(object):
+ """Supporting class for nested scope variable space for converter.Base.
+
+ This structure offers syntactic sugar over a dict of stacks of objects
+ of known type. These structures are useful to keep state during AST walks.
+ Multiple different scopes can be tracked in parallel. For example:
+
+ s = _State()
+
+ s[foo].enter()
+ s[bar].enter() # this will not affect s[foo]
+
+ Element access has special semantics:
+ * keys are a data type
+ * element values are _StateStack(type=key) objects
+ * missing elements are automatically added, similarly to defaultdict
+
+ For example, the following block:
+
+ s = _State()
+ s[Foo]
+
+ Is equivalent to:
+
+ s = {}
+ if Foo not in s:
+ s[Foo] = Foo()
+ s[Foo]
+
+ See Base for how it's used.
+ """
+
+ def __init__(self):
+ self._value = {}
+
+ def __getitem__(self, key):
+ if key not in self._value:
+ self._value[key] = _StateStack(key)
+ return self._value[key]
+
+
class Base(gast.NodeTransformer):
"""Base class for general-purpose code transformers transformers.
@@ -71,6 +168,27 @@ class Base(gast.NodeTransformer):
(possibly nested) scopes, use enter/exit_local_scope and set/get_local.
You must call enter/exit_local_scope manually, but the transformer detects
when they are not properly paired.
+
+ The transformer allows keeping state across calls to visit_* that is local to
+ arbitrary nodes and their descendants, using the self.state attribute.
+ Multiple independent scopes are allowed and automatically constructed.
+
+ For example, to keep track of the If node that encloses any Name node, one can
+ write:
+
+ class FooType(object):
+
+ def __init__(self):
+ self.foo_property = None
+
+ class DummyTransformer(Base):
+
+ def visit_If(self, node):
+ self.state[FooType].enter()
+ self.state[FooType].foo_property = node
+
+ def visit_Name(self, node):
+ self.state[FooType].foo_property # will hold the innermost enclosing if
"""
# TODO(mdan): Document all extra features.
@@ -92,6 +210,12 @@ class Base(gast.NodeTransformer):
self._local_scope_state = []
self.enter_local_scope()
+ # Allows scoping of local variables to keep state across calls to visit_*
+ # methods. Multiple scope hierarchies may exist and are keyed by tag. A scope
+ # is valid at one or more nodes and all its children. Scopes created in
+ # child nodes supersede their parent. Scopes are isolated from one another.
+ self.state = _State()
+
@property
def enclosing_entities(self):
return tuple(self._enclosing_entities)
@@ -101,7 +225,9 @@ class Base(gast.NodeTransformer):
return len(self._local_scope_state)
def enter_local_scope(self, inherit=None):
- """Marks entry into a new local scope.
+ """Deprecated. Use self.state instead.
+
+ Marks entry into a new local scope.
Args:
inherit: Optional enumerable of variable names to copy from the
@@ -116,7 +242,9 @@ class Base(gast.NodeTransformer):
self._local_scope_state.append(scope_entered)
def exit_local_scope(self, keep=None):
- """Marks exit from the current local scope.
+ """Deprecated. Use self.state instead.
+
+ Marks exit from the current local scope.
Args:
keep: Optional enumerable of variable names to copy into the
@@ -133,9 +261,11 @@ class Base(gast.NodeTransformer):
return scope_left
def set_local(self, name, value):
+ """Deprecated. Use self.state instead."""
self._local_scope_state[-1][name] = value
def get_local(self, name, default=None):
+ """Deprecated. Use self.state instead."""
return self._local_scope_state[-1].get(name, default)
def debug_print(self, node):
@@ -216,7 +346,7 @@ class Base(gast.NodeTransformer):
node_destination = new_destination
return results
- # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+ # TODO(mdan): Remove.
def apply_to_single_assignments(self, targets, values, apply_fn):
"""Applies a function to each individual assignment.
@@ -266,7 +396,8 @@ class Base(gast.NodeTransformer):
def _get_source(self, node):
try:
- return compiler.ast_to_source(node)
+ source, _ = compiler.ast_to_source(node)
+ return source
except AssertionError:
return '<could not convert AST to source>'
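The semantics of _State/_StateStack can be mimicked with a defaultdict of lists, which may make the syntactic sugar easier to see (illustrative sketch, not the actual implementation; CondState is a hypothetical state type):

    import collections

    class CondState(object):
      def __init__(self):
        self.node = None

    state = collections.defaultdict(list)   # type -> stack of instances
    state[CondState].append(CondState())    # state[CondState].enter()
    state[CondState][-1].node = 'outer if'  # attribute access at the stack top
    state[CondState].append(CondState())    # a nested scope
    state[CondState][-1].node = 'inner if'
    state[CondState].pop()                  # state[CondState].exit()
    assert state[CondState][-1].node == 'outer if'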
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index baf04653ae..19b80b09ac 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -93,6 +93,83 @@ class TransformerTest(test.TestCase):
inner_function, lambda_node),
anno.getanno(lambda_expr, 'enclosing_entities'))
+ def assertSameAnno(self, first, second, key):
+ self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
+
+ def assertDifferentAnno(self, first, second, key):
+ self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))
+
+ def test_state_tracking(self):
+
+ class LoopState(object):
+ pass
+
+ class CondState(object):
+ pass
+
+ class TestTransformer(transformer.Base):
+
+ def visit(self, node):
+ anno.setanno(node, 'loop_state', self.state[LoopState].value)
+ anno.setanno(node, 'cond_state', self.state[CondState].value)
+ return super(TestTransformer, self).visit(node)
+
+ def visit_While(self, node):
+ self.state[LoopState].enter()
+ node = self.generic_visit(node)
+ self.state[LoopState].exit()
+ return node
+
+ def visit_If(self, node):
+ self.state[CondState].enter()
+ node = self.generic_visit(node)
+ self.state[CondState].exit()
+ return node
+
+ tr = TestTransformer(self._simple_source_info())
+
+ def test_function(a):
+ a = 1
+ while a:
+ _ = 'a'
+ if a > 2:
+ _ = 'b'
+ while True:
+ raise '1'
+ if a > 3:
+ _ = 'c'
+ while True:
+ raise '1'
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+
+ fn_body = node.body[0].body
+ outer_while_body = fn_body[1].body
+ self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
+ self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')
+
+ first_if_body = outer_while_body[1].body
+ self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
+ 'cond_state')
+ self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')
+
+ first_inner_while_body = first_if_body[1].body
+ self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
+ 'cond_state')
+ self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
+ 'loop_state')
+
+ second_if_body = outer_while_body[2].body
+ self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
+ self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')
+
+ second_inner_while_body = second_if_body[1].body
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'cond_state')
+ self.assertDifferentAnno(first_inner_while_body[0],
+ second_inner_while_body[0], 'loop_state')
+
def test_local_scope_info_stack(self):
class TestTransformer(transformer.Base):
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index 47b80bdf4a..55faad983f 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -58,8 +58,6 @@ def batch_function(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
- grad_timeout_micros=60 * 1000 * 1000,
- unbatch_timeout_micros=60 * 1000 * 1000,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
@@ -94,10 +92,6 @@ def batch_function(num_batch_threads,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
- grad_timeout_micros: The timeout to use for the gradient. See the
- documentation of the unbatch op for more details. Defaults to 60s.
- unbatch_timeout_micros: The timeout to use for unbatching. See the
- documentation of the unbatch op for more details. Defaults to 60s.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
Returns:
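As context for the two removed timeout arguments, a sketch of how batch_function is applied as a decorator (illustrative only; the argument values are arbitrary):

    from tensorflow.contrib.batching.python.ops import batch_ops

    @batch_ops.batch_function(num_batch_threads=1,
                              max_batch_size=32,
                              batch_timeout_micros=100000)
    def times_two(t):
      return t * 2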
diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD
index 5c15d21e35..71538e0770 100644
--- a/tensorflow/contrib/bigtable/BUILD
+++ b/tensorflow/contrib/bigtable/BUILD
@@ -31,6 +31,7 @@ tf_custom_op_py_library(
srcs_version = "PY2AND3",
deps = [
":bigtable_ops",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
@@ -39,18 +40,24 @@ tf_custom_op_py_library(
],
)
+KERNEL_FILES = [
+ "kernels/bigtable_kernels.cc",
+ "kernels/bigtable_lookup_dataset_op.cc",
+ "kernels/bigtable_prefix_key_dataset_op.cc",
+ "kernels/bigtable_range_key_dataset_op.cc",
+ "kernels/bigtable_sample_keys_dataset_op.cc",
+ "kernels/bigtable_sample_key_pairs_dataset_op.cc",
+ "kernels/bigtable_scan_dataset_op.cc",
+]
+
tf_custom_op_library(
name = "python/ops/_bigtable.so",
- srcs = [
- "kernels/bigtable_kernels.cc",
- "kernels/bigtable_lookup_dataset_op.cc",
- "kernels/bigtable_prefix_key_dataset_op.cc",
- "kernels/bigtable_range_key_dataset_op.cc",
- "kernels/bigtable_scan_dataset_op.cc",
+ srcs = KERNEL_FILES + [
"ops/bigtable_ops.cc",
],
deps = [
":bigtable_lib_cc",
+ ":bigtable_range_helpers",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
@@ -69,15 +76,10 @@ tf_gen_op_libs(
tf_kernel_library(
name = "bigtable_kernels",
- srcs = [
- "kernels/bigtable_kernels.cc",
- "kernels/bigtable_lookup_dataset_op.cc",
- "kernels/bigtable_prefix_key_dataset_op.cc",
- "kernels/bigtable_range_key_dataset_op.cc",
- "kernels/bigtable_scan_dataset_op.cc",
- ],
+ srcs = KERNEL_FILES,
deps = [
":bigtable_lib_cc",
+ ":bigtable_range_helpers",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
@@ -97,6 +99,15 @@ cc_library(
)
cc_library(
+ name = "bigtable_range_helpers",
+ srcs = ["kernels/bigtable_range_helpers.cc"],
+ hdrs = ["kernels/bigtable_range_helpers.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+cc_library(
name = "bigtable_test_client",
srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
@@ -120,6 +131,17 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "bigtable_range_helpers_test",
+ size = "small",
+ srcs = ["kernels/bigtable_range_helpers_test.cc"],
+ deps = [
+ ":bigtable_range_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
tf_gen_op_wrapper_py(
name = "bigtable_test_ops",
deps = [":bigtable_test_ops_op_lib"],
@@ -168,11 +190,6 @@ tf_custom_op_py_library(
srcs_version = "PY2AND3",
deps = [
":bigtable_test_ops",
- # "//tensorflow/contrib/util:util_py",
- # "//tensorflow/python:framework_for_generated_wrappers",
- # "//tensorflow/python:platform",
- # "//tensorflow/python:util",
- # "//tensorflow/python/data",
],
)
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index f43b44f2cb..70923e6287 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -40,7 +40,16 @@ class BigtableClientOp : public OpKernel {
if (connection_pool_size_ == -1) {
connection_pool_size_ = 100;
}
- OP_REQUIRES(ctx, connection_pool_size_ > 0,
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
+ &max_receive_message_size_));
+ // If left unset by the client code, set it to a default of 16 MB. Note:
+ // the gRPC default of 4 MB is far too low for high performance streaming.
+ if (max_receive_message_size_ == -1) {
+ max_receive_message_size_ = 1 << 24; // 16 MBytes
+ }
+ OP_REQUIRES(ctx, max_receive_message_size_ > 0,
errors::InvalidArgument("connection_pool_size must be > 0"));
}
@@ -67,7 +76,15 @@ class BigtableClientOp : public OpKernel {
cinfo_.container(), cinfo_.name(), &resource,
[this, ctx](
BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- auto client_options = google::cloud::bigtable::ClientOptions();
+ auto client_options =
+ google::cloud::bigtable::ClientOptions()
+ .set_connection_pool_size(connection_pool_size_)
+ .set_data_endpoint("batch-bigtable.googleapis.com");
+ auto channel_args = client_options.channel_arguments();
+ channel_args.SetMaxReceiveMessageSize(
+ max_receive_message_size_);
+ channel_args.SetUserAgentPrefix("tensorflow");
+ client_options.set_channel_arguments(channel_args);
std::shared_ptr<google::cloud::bigtable::DataClient> client =
google::cloud::bigtable::CreateDefaultDataClient(
project_id_, instance_id_, std::move(client_options));
@@ -87,6 +104,7 @@ class BigtableClientOp : public OpKernel {
string project_id_;
string instance_id_;
int64 connection_pool_size_;
+ int32 max_receive_message_size_;
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
@@ -240,6 +258,12 @@ class ToBigtableOp : public AsyncOpKernel {
grpc::Status mutation_status;
std::vector<::google::cloud::bigtable::FailedMutation> failures =
resource->table().BulkApply(std::move(mutation), mutation_status);
+ if (!mutation_status.ok()) {
+ LOG(ERROR) << "Failure applying mutation: "
+ << mutation_status.error_code() << " - "
+ << mutation_status.error_message() << " ("
+ << mutation_status.error_details() << ").";
+ }
if (!failures.empty()) {
for (const auto& failure : failures) {
LOG(ERROR) << "Failure applying mutation on row ("
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index 12d8256dea..a2a5df1037 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -58,7 +58,8 @@ class BigtableTableResource : public ResourceBase {
BigtableTableResource(BigtableClientResource* client, string table_name)
: client_(client),
table_name_(std::move(table_name)),
- table_(client->get_client(), table_name_) {
+ table_(client->get_client(), table_name_,
+ google::cloud::bigtable::AlwaysRetryMutationPolicy()) {
client_->Ref();
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
new file mode 100644
index 0000000000..51965f6214
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+string MakePrefixEndKey(const string& prefix) {
+ string end = prefix;
+ while (true) {
+ if (end.empty()) {
+ return end;
+ }
+ ++end[end.size() - 1];
+ if (end[end.size() - 1] == 0) {
+ // Handle wraparound case.
+ end = end.substr(0, end.size() - 1);
+ } else {
+ return end;
+ }
+ }
+}
+
+} // namespace
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) {
+ string end = MakePrefixEndKey(prefix);
+ VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix
+ << ", with end key: " << end;
+ return MultiModeKeyRange(std::move(prefix), std::move(end));
+}
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin,
+ string end) {
+ return MultiModeKeyRange(std::move(begin), std::move(end));
+}
+
+const string& MultiModeKeyRange::begin_key() const { return begin_; }
+
+const string& MultiModeKeyRange::end_key() const { return end_; }
+
+bool MultiModeKeyRange::contains_key(StringPiece key) const {
+ if (StringPiece(begin_) > key) {
+ return false;
+ }
+ if (StringPiece(end_) <= key && !end_.empty()) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace tensorflow
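MakePrefixEndKey computes the smallest key strictly greater than every key sharing the prefix: increment the last byte, dropping any trailing bytes that wrap around. The same logic as a short Python sketch (illustrative; the C++ above is authoritative):

    def make_prefix_end_key(prefix):
      end = bytearray(prefix, 'latin-1')
      while end:
        end[-1] = (end[-1] + 1) % 256
        if end[-1] != 0:
          return end.decode('latin-1')
        end.pop()  # the byte wrapped to 0: drop it and carry left
      return ''    # empty prefix, or a prefix of all 0xff bytes

    assert make_prefix_end_key('prefix') == 'prefiy'  # matches the test below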
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
new file mode 100644
index 0000000000..44c628e366
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Represents a contiguous range of keys defined by either a prefix or a range.
+//
+// Ranges are represented as "half-open", where the beginning key is included
+// in the range, and the end_key is the first excluded key after the range.
+//
+// The range of keys can be specified either by a key prefix, or by an explicit
+// begin key and end key. All methods on this class are valid no matter which
+// way the range was specified.
+//
+// Example:
+// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix");
+// if (range.contains_key("myPrefixedKey")) {
+// LOG(INFO) << "range from " << range.begin_key() << " to "
+//               << range.end_key() << " contains \"myPrefixedKey\"";
+// }
+// if (!range.contains_key("randomKey")) {
+// LOG(INFO) << "range does not contain \"randomKey\"";
+// }
+// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key");
+class MultiModeKeyRange {
+ public:
+ static MultiModeKeyRange FromPrefix(string prefix);
+ static MultiModeKeyRange FromRange(string begin, string end);
+
+ // The first valid key in the range.
+ const string& begin_key() const;
+ // The first invalid key after the valid range.
+ const string& end_key() const;
+ // Returns true if the provided key is a part of the range, false otherwise.
+ bool contains_key(StringPiece key) const;
+
+ private:
+ MultiModeKeyRange(string begin, string end)
+ : begin_(std::move(begin)), end_(std::move(end)) {}
+
+ const string begin_;
+ const string end_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
new file mode 100644
index 0000000000..1bfc547271
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(MultiModeKeyRangeTest, SimplePrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix");
+ EXPECT_EQ("prefix", r.begin_key());
+ EXPECT_EQ("prefiy", r.end_key());
+ EXPECT_TRUE(r.contains_key("prefixed_key"));
+ EXPECT_FALSE(r.contains_key("not-prefixed-key"));
+ EXPECT_FALSE(r.contains_key("prefi"));
+ EXPECT_FALSE(r.contains_key("prefiy"));
+ EXPECT_FALSE(r.contains_key("early"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, Range) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b");
+ EXPECT_EQ("a", r.begin_key());
+ EXPECT_EQ("b", r.end_key());
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("ab"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key("bc"));
+ EXPECT_FALSE(r.contains_key("A"));
+ EXPECT_FALSE(r.contains_key("B"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, InvertedRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a");
+ EXPECT_FALSE(r.contains_key("a"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, EmptyPrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("");
+ EXPECT_EQ("", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key(""));
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("z"));
+ EXPECT_TRUE(r.contains_key("A"));
+ EXPECT_TRUE(r.contains_key("ZZZZZZ"));
+}
+
+TEST(MultiModeKeyRangeTest, HalfRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", "");
+ EXPECT_EQ("start", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key("start"));
+ EXPECT_TRUE(r.contains_key("starting"));
+ EXPECT_TRUE(r.contains_key("z-end"));
+ EXPECT_FALSE(r.contains_key(""));
+ EXPECT_FALSE(r.contains_key("early"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixWrapAround) {
+ string prefix = "abc\xff";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abd", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\xff\x07"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x15"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x61"));
+ EXPECT_TRUE(r.contains_key("abc\xff\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abd"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) {
+ string prefix = "abc\x7f";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abc\x80", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\x7f\x07"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x15"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x61"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abc\x01"));
+ EXPECT_FALSE(r.contains_key("abd"));
+ EXPECT_FALSE(r.contains_key("ab\x80"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
new file mode 100644
index 0000000000..a1a63a975a
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
+ errors::InvalidArgument(
+ "Only one of prefix and start_key can be provided"));
+ if (!prefix.empty()) {
+ OP_REQUIRES(ctx, end_key.empty(),
+ errors::InvalidArgument(
+ "If prefix is specified, end_key must be empty."));
+ }
+
+ *output = new Dataset(ctx, resource, std::move(prefix),
+ std::move(start_key), std::move(end_key));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix, string start_key, string end_key)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ key_range_(MakeMultiModeKeyRange(
+ std::move(prefix), std::move(start_key), std::move(end_key))) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableSampleKeyPairsDatasetOp::Dataset";
+ }
+
+ private:
+ static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
+ string start_key,
+ string end_key) {
+ if (!start_key.empty()) {
+ return MultiModeKeyRange::FromRange(std::move(start_key),
+ std::move(end_key));
+ }
+ return MultiModeKeyRange::FromPrefix(std::move(prefix));
+ }
+
+ BigtableTableResource& table() const { return *table_; }
+
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ // Computes split points (`keys_`) to use when scanning the table.
+ //
+ // Initialize first retrieves the sample keys from the table (`row_keys`),
+ // as these often form good split points within the table. We then iterate
+ // over them, and copy them to `keys_` if they fall within the requested
+ // range to scan (`dataset()->key_range_`). Because the requested range
+ // might start between elements of the sampled keys list, care is taken to
+ // ensure we don't accidentally miss any subsets of the requested range by
+ // including `begin_key()` and `end_key()` as appropriate.
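+        //
+        // An illustrative walk-through (sample values assumed for
+        // exposition): with sampled row keys {"a", "m", "z"} and a requested
+        // range ["f", "w"), keys_ becomes {"f", "m", "w"}, and
+        // GetNextInternal then yields the pairs ("f", "m") and ("m", "w").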
+ Status Initialize(IteratorContext* ctx) override {
+ grpc::Status status;
+ std::vector<google::cloud::bigtable::RowKeySample> row_keys =
+ dataset()->table().table().SampleRows(status);
+ if (!status.ok()) {
+ return GrpcStatusToTfStatus(status);
+ }
+
+ for (size_t i = 0; i < row_keys.size(); ++i) {
+ string row_key(row_keys[i].row_key);
+ if (dataset()->key_range_.contains_key(row_key)) {
+ // First key: check to see if we need to add the begin_key.
+ if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+ keys_.push_back(std::move(row_key));
+ } else if (!keys_.empty()) {
+ // If !keys_.empty(), then we have found at least one element of
+ // `row_keys` that is within our requested range
+ // (`dataset()->key_range_`). Because `row_keys` is sorted, if we
+ // have found an element that's not within our key range, then we
+ // are after our requested range (ranges are contiguous) and can end
+ // iteration early.
+ break;
+ }
+ }
+
+ // Handle the case where we skip over the selected range entirely.
+ if (keys_.empty()) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+
+ // Last key: check to see if we need to add the end_key.
+ if (keys_.back() != dataset()->key_range_.end_key()) {
+ keys_.push_back(dataset()->key_range_.end_key());
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+        // index_ + 2 avoids unsigned underflow when keys_ holds fewer than
+        // two entries (possible when no keys are sampled and the begin and
+        // end keys are equal).
+        if (index_ + 2 > keys_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ *end_of_sequence = false;
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_];
+
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_ + 1];
+ ++index_;
+
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ size_t index_ GUARDED_BY(mu_) = 0;
+ // Note: we store the keys_ on the iterator instead of the dataset
+ // because we want to re-sample the row keys in case there have been
+ // tablet rebalancing operations since the dataset was created.
+ //
+ // Note: keys_ is readonly after Initialize, and thus does not need a
+ // guarding lock.
+ std::vector<string> keys_;
+ };
+
+ BigtableTableResource* const table_;
+ const MultiModeKeyRange key_range_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
+ BigtableSampleKeyPairsDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
new file mode 100644
index 0000000000..a5a47cfe2d
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ *output = new Dataset(ctx, resource);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
+ : GraphDatasetBase(ctx), table_(table) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+      return "BigtableSampleKeysDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ ::grpc::Status status;
+ row_keys_ = dataset()->table()->table().SampleRows(status);
+ if (!status.ok()) {
+ row_keys_.clear();
+ return GrpcStatusToTfStatus(status);
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ < row_keys_.size()) {
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() =
+ string(row_keys_[index_].row_key);
+ *end_of_sequence = false;
+ index_++;
+ } else {
+ *end_of_sequence = true;
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+      size_t index_ GUARDED_BY(mu_) = 0;
+ std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
+ };
+
+ BigtableTableResource* const table_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
+ BigtableSampleKeysDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
index c164682508..f083ce6f44 100644
--- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
@@ -63,24 +63,29 @@ class SampleRowKeysResponse : public grpc::ClientReaderInterface<
bool NextMessageSize(uint32_t* sz) override {
mutex_lock l(mu_);
- if (sent_first_message_) {
- return false;
+ mutex_lock l2(client_->mu_);
+ if (num_messages_sent_ * 2 < client_->table_.rows.size()) {
+      *sz = 10000;  // A value comfortably larger than any message we send.
+ return true;
}
- *sz = 10000; // A sufficiently high enough value to not worry about.
- return true;
+ return false;
}
bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override {
+ // Send every other key from the table.
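+    // For example, with rows {r1, r2, r3, r4, r5}, successive calls return
+    // r1, r3, and r5 (see the SampleKeys tests).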
mutex_lock l(mu_);
- if (sent_first_message_) {
- return false;
- }
- sent_first_message_ = true;
-
mutex_lock l2(client_->mu_);
*resp = google::bigtable::v2::SampleRowKeysResponse();
- resp->set_row_key(client_->table_.rows.begin()->first);
- resp->set_offset_bytes(0);
+ auto itr = client_->table_.rows.begin();
+ for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) {
+ ++itr;
+ if (itr == client_->table_.rows.end()) {
+ return false;
+ }
+ }
+ resp->set_row_key(itr->first);
+ resp->set_offset_bytes(100 * num_messages_sent_);
+ num_messages_sent_++;
return true;
}
@@ -90,7 +95,7 @@ class SampleRowKeysResponse : public grpc::ClientReaderInterface<
private:
mutex mu_;
- bool sent_first_message_ GUARDED_BY(mu_) = false;
+ int64 num_messages_sent_ GUARDED_BY(mu_) = 0;
BigtableTestClient* client_; // Not owned.
};
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
index d6b3964719..32611e2590 100644
--- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
@@ -286,5 +286,60 @@ TEST(BigtableTestClientTest, RowKeys) {
EXPECT_TRUE(rows.Finish().ok());
}
+TEST(BigtableTestClientTest, SampleKeys) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+ WriteCell("r4", "f1", "c1", "v4", &table);
+ WriteCell("r5", "f1", "c1", "v5", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(3, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+ EXPECT_EQ(0, resp[0].offset_bytes);
+ EXPECT_EQ("r3", string(resp[1].row_key));
+ EXPECT_EQ(100, resp[1].offset_bytes);
+ EXPECT_EQ("r5", string(resp[2].row_key));
+ EXPECT_EQ(200, resp[2].offset_bytes);
+}
+
+TEST(BigtableTestClientTest, SampleKeysShort) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(1, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+}
+
+TEST(BigtableTestClientTest, SampleKeysEvenNumber) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+ WriteCell("r4", "f1", "c1", "v4", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(2, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+ EXPECT_EQ("r3", string(resp[1].row_key));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
index c7ff012ec8..416b719e30 100644
--- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -23,6 +23,7 @@ REGISTER_OP("BigtableClient")
.Attr("project_id: string")
.Attr("instance_id: string")
.Attr("connection_pool_size: int")
+ .Attr("max_receive_message_size: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("client: resource")
@@ -71,6 +72,23 @@ REGISTER_OP("BigtableRangeKeyDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("BigtableSampleKeysDataset")
+ .Input("table: resource")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableSampleKeyPairsDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
// skip incomplete row.)
REGISTER_OP("BigtableScanDataset")
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index d33a66f2df..2f20064619 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -21,8 +21,10 @@ from __future__ import print_function
from tensorflow.contrib import bigtable
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
+from tensorflow.contrib.bigtable.python.ops import bigtable_api
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -31,6 +33,10 @@ _bigtable_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_bigtable_test.so"))
+def _ListOfTuplesOfStringsToBytes(values):
+ return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]
+
+
class BigtableOpsTest(test.TestCase):
COMMON_ROW_KEYS = ["r1", "r2", "r3"]
COMMON_VALUES = ["v1", "v2", "v3"]
@@ -99,12 +105,18 @@ class BigtableOpsTest(test.TestCase):
def testScanPrefixListCol(self):
self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
+ def testScanPrefixTupleCol(self):
+ self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))
+
def testScanRangeStringCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
def testScanRangeListCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
+ def testScanRangeTupleCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))
+
def testLookup(self):
ds = self._table.keys_by_prefix_dataset("r")
ds = ds.apply(self._table.lookup_columns(cf1="c1"))
@@ -127,6 +139,134 @@ class BigtableOpsTest(test.TestCase):
"Unequal values at step %d: want: %s, got: %s" %
(i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
+ def testSampleKeys(self):
+ ds = self._table.sample_keys()
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_key = self.COMMON_ROW_KEYS[0]
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ output = sess.run(n)
+      self.assertEqual(
+          compat.as_bytes(expected_key), compat.as_bytes(output),
+          "Unequal keys: want: %s, got: %s" %
+          (compat.as_bytes(expected_key), compat.as_bytes(output)))
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
+ "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
+ self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+ def runSampleKeyPairsTest(self, ds, expected_key_pairs):
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i, elems in enumerate(expected_key_pairs):
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
+ "Unequal key pair (first element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
+ self.assertEqual(
+ compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
+ "Unequal key pair (second element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+ def testSampleKeyPairsSimplePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="")
+ expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSimpleRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r1", end="r3")
+ expected_key_pairs = [("r1", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r2", start="", end="")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangeRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r3")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsOffsetRanges(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r4")
+ expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+  def testSampleKeyPairsEverything(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="", end="")
+ expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsPrefixAndStartKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="r1", end="")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testSampleKeyPairsPrefixAndEndKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="r3")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testParallelScanPrefix(self):
+ ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
+ def testParallelScanRange(self):
+ ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index 39c58ba665..9f73b7223c 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -28,8 +28,10 @@ from __future__ import division
from __future__ import print_function
from six import iteritems
+from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
@@ -49,7 +51,11 @@ class BigtableClient(object):
`table` method to open a Bigtable Table.
"""
- def __init__(self, project_id, instance_id, connection_pool_size=None):
+ def __init__(self,
+ project_id,
+ instance_id,
+ connection_pool_size=None,
+ max_receive_message_size=None):
"""Creates a BigtableClient that can be used to open connections to tables.
Args:
@@ -57,6 +63,8 @@ class BigtableClient(object):
instance_id: A string representing the Bigtable instance to connect to.
connection_pool_size: (Optional.) A number representing the number of
concurrent connections to the Cloud Bigtable service to make.
+ max_receive_message_size: (Optional.) The maximum bytes received in a
+ single gRPC response.
Raises:
ValueError: if the arguments are invalid (e.g. wrong type, or out of
@@ -74,10 +82,16 @@ class BigtableClient(object):
connection_pool_size = -1
elif connection_pool_size < 1:
raise ValueError("`connection_pool_size` must be positive")
+
+ if max_receive_message_size is None:
+ max_receive_message_size = -1
+ elif max_receive_message_size < 1:
+ raise ValueError("`max_receive_message_size` must be positive")
+
self._connection_pool_size = connection_pool_size
- self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id,
- connection_pool_size)
+ self._resource = gen_bigtable_ops.bigtable_client(
+ project_id, instance_id, connection_pool_size, max_receive_message_size)
def table(self, name, snapshot=None):
"""Opens a table and returns a `BigTable` object.
@@ -205,6 +219,18 @@ class BigTable(object):
"""
return _BigtablePrefixKeyDataset(self, prefix)
+ def sample_keys(self):
+ """Retrieves a sampling of row keys from the Bigtable table.
+
+ This dataset is most often used in conjunction with
+ @{tf.contrib.data.parallel_interleave} to construct a set of ranges for
+ scanning in parallel.
+
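+    For example (an illustrative sketch, assuming `table` is an already-opened
+    `BigTable` object):
+
+    ```
+    keys_ds = table.sample_keys()
+    itr = keys_ds.make_initializable_iterator()
+    next_key = itr.get_next()  # A scalar string tensor: one sampled row key.
+    ```
+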
+ Returns:
+ A @{tf.data.Dataset} returning string row keys.
+ """
+ return _BigtableSampleKeysDataset(self)
+
def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
"""Retrieves row (including values) from the Bigtable service.
@@ -227,9 +253,11 @@ class BigTable(object):
Note: only the latest value of a cell will be retrieved.
Args:
- prefix: The prefix all row keys muat match to be retrieved for prefix-
+ prefix: The prefix all row keys must match to be retrieved for prefix-
based scans.
- probability: Probabilistically sample rows.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
@@ -244,26 +272,8 @@ class BigTable(object):
Raises:
ValueError: If the configured probability is unexpected.
"""
- if probability is None:
- probability = 1.0
- if isinstance(probability, float) and (probability <= 0.0 or
- probability > 1.0):
- raise ValueError("probability must be in the range (0, 1].")
-
- normalized = columns
- if normalized is None:
- normalized = []
- if isinstance(normalized, tuple):
- normalized = list(normalized)
- for key, value in iteritems(kwargs):
- if key == "name":
- continue
- if isinstance(value, str):
- normalized.append((key, value))
- continue
- for col in value:
- normalized.append((key, col))
-
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
def scan_range(self, start, end, probability=None, columns=None, **kwargs):
@@ -290,7 +300,9 @@ class BigTable(object):
Args:
start: The start of the range when scanning by range.
end: (Optional.) The end of the range when scanning by range.
- probability: Probabilistically sample rows.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
@@ -305,27 +317,129 @@ class BigTable(object):
Raises:
ValueError: If the configured probability is unexpected.
"""
- if probability is None:
- probability = 1.0
- if isinstance(probability, float) and (probability <= 0.0 or
- probability > 1.0):
- raise ValueError("probability must be in the range (0, 1].")
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ return _BigtableScanDataset(self, "", start, end, normalized, probability)
- normalized = columns
- if normalized is None:
- normalized = []
- if isinstance(normalized, tuple):
- normalized = list(normalized)
- for key, value in iteritems(kwargs):
- if key == "name":
- continue
- if isinstance(value, str):
- normalized.append((key, value))
- continue
- for col in value:
- normalized.append((key, col))
+ def parallel_scan_prefix(self,
+ prefix,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+    """Retrieves rows (including values) from the Bigtable service at high speed.
- return _BigtableScanDataset(self, "", start, end, normalized, probability)
+    Rows whose row keys are prefixed by `prefix` will be retrieved. This
+    method is similar to `scan_prefix`, but by contrast performs multiple
+    sub-scans in parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ prefix: The prefix all row keys must match to be retrieved for prefix-
+ based scans.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "")
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
+
+ def parallel_scan_range(self,
+ start,
+ end,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+ """Retrieves rows (including values) from the Bigtable service.
+
+    Rows with row keys between `start` and `end` will be retrieved. This
+    method is similar to `scan_range`, but by contrast performs multiple
+    sub-scans in parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_range("row_start",
+ "row_end",
+ columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_range("row_start", "row_end",
+ cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ start: The start of the range when scanning by range.
+ end: (Optional.) The end of the range when scanning by range.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, "", start, end)
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
def write(self, dataset, column_families, columns, timestamp=None):
"""Writes a dataset to the table.
@@ -372,6 +486,89 @@ class BigTable(object):
columns,
timestamp)
+ def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
+ normalized_probability, normalized_columns):
+ """Builds a parallel dataset from a given range.
+
+ Args:
+ ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
+ num_parallel_scans: The number of concurrent parallel scans to use.
+ normalized_probability: A number between 0 and 1 for the keep probability.
+ normalized_columns: The column families and column qualifiers to retrieve.
+
+ Returns:
+ A @{tf.data.Dataset} representing the result of the parallel scan.
+ """
+ if num_parallel_scans is None:
+ num_parallel_scans = 50
+
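+    # Shuffle the sampled key ranges so concurrent sub-scans are spread
+    # across the table (presumably to avoid hotspotting consecutive tablets).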
+ ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable.
+
+ def _interleave_fn(start, end):
+ return _BigtableScanDataset(
+ self,
+ prefix="",
+ start=start,
+ end=end,
+ normalized=normalized_columns,
+ probability=normalized_probability)
+
+    # Note: prefetch_input_elements must be set in order to avoid RPC
+    # timeouts.
+ ds = ds.apply(
+ interleave_ops.parallel_interleave(
+ _interleave_fn,
+ cycle_length=num_parallel_scans,
+ sloppy=True,
+ prefetch_input_elements=1))
+ return ds
+
+
+def _normalize_probability(probability):
+ if probability is None:
+ probability = 1.0
+ if isinstance(probability, float) and (probability <= 0.0 or
+ probability > 1.0):
+ raise ValueError("probability must be in the range (0, 1].")
+ return probability
+
+
+def _normalize_columns(columns, provided_kwargs):
+  """Converts arguments (columns and the kwargs dict) to the C++ representation.
+
+  Args:
+    columns: a data structure containing the column families and qualifiers to
+      retrieve. Valid types include (1) None, (2) a list of tuples, and (3) a
+      tuple of strings.
+    provided_kwargs: a dictionary containing the column families and qualifiers
+      to retrieve.
+
+ Returns:
+ A list of pairs of column family+qualifier to retrieve.
+
+ Raises:
+ ValueError: If there are no cells to retrieve or the columns are in an
+ incorrect format.
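+
+  For example (illustrative):
+    _normalize_columns([("cfa", "c1")], {"cfb": ["c2", "c3"]})
+    # -> [("cfa", "c1"), ("cfb", "c2"), ("cfb", "c3")]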
+ """
+ normalized = columns
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ if len(normalized) == 2:
+ normalized = [normalized]
+ else:
+ raise ValueError("columns was a tuple of inappropriate length")
+ for key, value in iteritems(provided_kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, string_types):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+ if not normalized:
+ raise ValueError("At least one column + column family must be specified.")
+ return normalized
+
class _BigtableKeyDataset(dataset_ops.Dataset):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
@@ -429,6 +626,20 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset):
end_key=self._end)
+class _BigtableSampleKeysDataset(_BigtableKeyDataset):
+  """_BigtableSampleKeysDataset represents a sampling of row keys."""
+
+ # TODO(saeta): Expose the data size offsets into the keys.
+
+ def __init__(self, table):
+ super(_BigtableSampleKeysDataset, self).__init__(table)
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_sample_keys_dataset(
+ table=self._table._resource) # pylint: disable=protected-access
+
+
class _BigtableLookupDataset(dataset_ops.Dataset):
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
"""
@@ -497,3 +708,34 @@ class _BigtableScanDataset(dataset_ops.Dataset):
column_families=self._column_families,
columns=self._columns,
probability=self._probability)
+
+
+class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+  """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table."""
+
+ def __init__(self, table, prefix, start, end):
+ self._table = table
+ self._prefix = prefix
+ self._start = start
+ self._end = end
+
+ @property
+ def output_classes(self):
+ return (ops.Tensor, ops.Tensor)
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return (dtypes.string, dtypes.string)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
+ table=self._table._resource,
+ prefix=self._prefix,
+ start_key=self._start,
+ end_key=self._end)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 9c36c30221..59a78515c6 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -269,3 +269,88 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class GradientBoostedDecisionTreeRanker(estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+    This is an estimator that can be trained on pairwise data and used for
+    inference on non-paired data. This is essentially LambdaMART.
+
+    Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+        during inference. The leaf node indices are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+ super(GradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=model.ranking_model_builder,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
+ },
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 75ef1b0500..2c2dcb039d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -37,12 +37,31 @@ def _train_input_fn():
return features, label
+def _ranking_train_input_fn():
+ features = {
+ "a.f1": constant_op.constant([[3.], [0.3], [1.]]),
+ "a.f2": constant_op.constant([[0.1], [3.], [1.]]),
+ "b.f1": constant_op.constant([[13.], [0.4], [5.]]),
+ "b.f2": constant_op.constant([[1.], [3.], [0.01]]),
+ }
+ label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32)
+ return features, label
+
+
def _eval_input_fn():
features = {"x": constant_op.constant([[1.], [2.], [2.]])}
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
return features, label
+def _infer_ranking_train_input_fn():
+ features = {
+      "f1": constant_op.constant([[3.], [2.], [1.]]),
+ "f2": constant_op.constant([[0.1], [3.], [1.]])
+ }
+ return features, None
+
+
class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -155,6 +174,34 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
regressor.evaluate(input_fn=_eval_input_fn, steps=1)
regressor.export(self._export_dir_base)
+  def testRankingDontThrowExceptionForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ model = estimator.GradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ use_core_libs=True,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ model.fit(input_fn=_ranking_train_input_fn, steps=1000)
+ model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ model.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 1ee8911989..0e8a56e6e9 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import copy
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -28,7 +29,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
-
def model_builder(features, labels, mode, params, config):
"""Multi-machine batch gradient descent tree model.
@@ -141,3 +141,184 @@ def model_builder(features, labels, mode, params, config):
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees))
return model_fn_ops
+
+
+def ranking_model_builder(features, labels, mode, params, config):
+ """Multi-machine batch gradient descent tree model for ranking.
+
+ Args:
+ features: `Tensor` or `dict` of `Tensor` objects.
+ labels: Labels used to train on.
+ mode: Mode we are in. (TRAIN/EVAL/INFER)
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * learner_config: A config for the learner.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ * weight_column_name: The name of weight column.
+ * center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ * ranking_model_pair_keys (Optional): Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ config: `RunConfig` of the estimator.
+
+ Returns:
+ A `ModelFnOps` object.
+
+  Raises:
+ ValueError: if inputs are not valid.
+ """
+ head = params["head"]
+ learner_config = params["learner_config"]
+ examples_per_layer = params["examples_per_layer"]
+ feature_columns = params["feature_columns"]
+ weight_column_name = params["weight_column_name"]
+ num_trees = params["num_trees"]
+ use_core_libs = params["use_core_libs"]
+ logits_modifier_function = params["logits_modifier_function"]
+ output_leaf_index = params["output_leaf_index"]
+ ranking_model_pair_keys = params["ranking_model_pair_keys"]
+
+ if features is None:
+ raise ValueError("At least one feature must be specified.")
+
+ if config is None:
+ raise ValueError("Missing estimator RunConfig.")
+
+ center_bias = params["center_bias"]
+
+ if isinstance(features, ops.Tensor):
+ features = {features.name: features}
+
+ # Make a shallow copy of features to ensure downstream usage
+ # is unaffected by modifications in the model function.
+ training_features = copy.copy(features)
+ training_features.pop(weight_column_name, None)
+ global_step = training_util.get_global_step()
+ with ops.device(global_step.device):
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config="", # Initialize an empty ensemble.
+ name="ensemble_model")
+
+ # Extract the features.
+ if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL:
+ # For ranking pairwise training, we extract two sets of features.
+ if len(ranking_model_pair_keys) != 2:
+      raise ValueError("ranking_model_pair_keys must contain exactly two keys.")
+ left_pair_key = ranking_model_pair_keys[0]
+ right_pair_key = ranking_model_pair_keys[1]
+ if left_pair_key is None or right_pair_key is None:
+ raise ValueError("Both pair keys should be provided for ranking.")
+
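+    # For example (illustrative): with ranking_model_pair_keys ("a", "b") and
+    # features {"a.f1": t1, "b.f1": t2}, the loop below produces
+    # features_1 = {"f1": t1} and features_2 = {"f1": t2}.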
+ features_1 = {}
+ features_2 = {}
+ for name in training_features:
+ feature = training_features[name]
+      # Strip the pair-key prefix (e.g. "a.") from the feature name; using the
+      # key's length supports pair keys longer than one character.
+      if name.startswith(left_pair_key + "."):
+        features_1[name[len(left_pair_key) + 1:]] = feature
+      else:
+        assert name.startswith(right_pair_key + ".")
+        features_2[name[len(right_pair_key) + 1:]] = feature
+
+ main_features = features_1
+ supplementary_features = features_2
+ else:
+ # For non-ranking or inference ranking, we have only 1 set of features.
+ main_features = training_features
+
+ # Create GBDT model.
+ gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=main_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ with ops.name_scope("gbdt", "gbdt_optimizer"):
+ # Logits for inference.
+ if mode == learn.ModeKeys.INFER:
+ predictions_dict = gbdt_model_main.predict(mode)
+ logits = predictions_dict[gbdt_batch.PREDICTIONS]
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+ else:
+ gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=supplementary_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ # Logits for train and eval.
+ if not supplementary_features:
+ raise ValueError("Features for ranking must be specified.")
+
+ predictions_dict_1 = gbdt_model_main.predict(mode)
+ predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS]
+
+ predictions_dict_2 = gbdt_model_supplementary.predict(mode)
+ predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS]
+
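+      # Pairwise (LambdaMART-style) ranking: a pair's logit is the difference
+      # of the two sides' scores, so the head's binary loss pushes the
+      # preferred example's score above its counterpart's.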
+ logits = predictions_1 - predictions_2
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+
+ predictions_dict = predictions_dict_1
+ predictions_dict[gbdt_batch.PREDICTIONS] = logits
+
+ def _train_op_fn(loss):
+ """Returns the op to optimize the loss."""
+ update_op = gbdt_model_main.train(loss, predictions_dict, labels)
+ with ops.control_dependencies(
+ [update_op]), (ops.colocate_with(global_step)):
+ update_op = state_ops.assign_add(global_step, 1).op
+ return update_op
+
+ create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+ if num_trees:
+ if center_bias:
+ num_trees += 1
+ finalized_trees, attempted_trees = (
+ gbdt_model_main.get_number_of_trees_tensor())
+ model_fn_ops.training_hooks.append(
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees))
+ return model_fn_ops
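
Note on the pairwise ranking hunk above: the model_fn splits the incoming feature dictionary by pair-key prefix, scores each side with a GBDT model, and uses the difference of the two scores as the logit. A minimal pure-Python sketch of that flow (illustration only; the feature dict, the "l"/"r" pair keys, and `score()` standing in for `gbdt_model.predict(mode)` are all hypothetical):

    # Hypothetical feature dict keyed by pair-key prefix.
    training_features = {"l.age": 1.0, "l.clicks": 3.0,
                         "r.age": 2.0, "r.clicks": 5.0}
    left_pair_key, right_pair_key = "l", "r"

    features_1, features_2 = {}, {}
    for name, feature in training_features.items():
      if name.startswith(left_pair_key + "."):
        features_1[name[len(left_pair_key) + 1:]] = feature
      else:
        features_2[name[len(right_pair_key) + 1:]] = feature

    def score(features):
      # Stand-in for the per-side GBDT prediction.
      return sum(features.values())

    # The pairwise ranking logit is the difference of the two scores.
    logits = score(features_1) - score(features_2)
    print(logits)  # 4.0 - 7.0 = -3.0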
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index a7e7bfc13c..69bb8fd4ad 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -51,7 +51,7 @@ class WeightedQuantilesSummary {
SummaryEntry() {
memset(this, 0, sizeof(*this));
- value = 0;
+ value = ValueType();
weight = 0;
min_rank = 0;
max_rank = 0;
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index ab7ac2aba6..b5ebaf1999 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -23,6 +23,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+
+
+def per_example_squared_hinge_loss(labels, weights, predictions):
+ """Squared hinge loss given labels, example weights and predictions."""
+ loss = losses.hinge_loss(labels=labels, logits=predictions, weights=weights,
+ reduction=losses.Reduction.NONE)
+ return math_ops.square(loss), control_flow_ops.no_op()
def per_example_logistic_loss(labels, weights, predictions):
@@ -126,7 +132,7 @@ def per_example_squared_loss(labels, weights, predictions):
def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
- """Exponential loss given labels, example weights and predictions.
+ """Trimmed exponential loss given labels, example weights and predictions.
Note that this is only for binary classification.
If logistic loss tries to make sure that the classifier is certain of its
@@ -211,3 +217,62 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
unweighted_loss = exp_with_logits(
name=name, eps=eps, labels=labels, logits=predictions)
return unweighted_loss * weights, control_flow_ops.no_op()
+
+
+def per_example_full_exp_loss(labels, weights, predictions, name=None):
+ """Full exponential loss given labels, example weights and predictions.
+
+ Note that this is only for binary classification.
+ The loss returned is exp(-targets * logits), where the targets are
+ converted to -1 and 1.
+
+ Args:
+ labels: Rank 2 (N, D) tensor of per-example labels.
+ weights: Rank 2 (N, 1) tensor of per-example weights.
+ predictions: Rank 2 (N, D) tensor of per-example predictions.
+ name: A name for the operation (optional).
+
+ Returns:
+ loss: A Rank 2 (N, 1) tensor of per-example exponential loss.
+ update_op: An update operation to update the loss's internal state.
+ """
+
+ def full_exp_with_logits(name, labels=None, logits=None):
+ """Computes exponential loss given `logits`.
+
+ Args:
+ name: A name for the operation (optional).
+ labels: A `Tensor` of the same type and shape as `logits`.
+ logits: A `Tensor` of type `float32` or `float64`.
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ exponential losses.
+
+ Raises:
+ ValueError: If `logits` and `labels` do not have the same shape.
+ """
+ with ops.name_scope(name, "exp_loss", [logits, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ labels = ops.convert_to_tensor(labels, name="labels")
+ try:
+ labels.get_shape().merge_with(logits.get_shape())
+ except ValueError:
+ raise ValueError("logits and labels must have the same shape (%s vs %s)"
+ % (logits.get_shape(), labels.get_shape()))
+
+ # Default threshold of 0 to switch between classes
+ zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
+ ones = array_ops.ones_like(logits, dtype=logits.dtype)
+ neg_ones = -array_ops.ones_like(logits, dtype=logits.dtype)
+
+ # Convert labels to 1 and -1
+ cond_labels = (labels > zeros)
+ labels_converted = array_ops.where(cond_labels, ones, neg_ones)
+
+ return math_ops.exp(-1.0 * logits * labels_converted)
+
+ labels = math_ops.to_float(labels)
+ unweighted_loss = full_exp_with_logits(
+ name=name, labels=labels, logits=predictions)
+ return unweighted_loss * weights, control_flow_ops.no_op()
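
As a quick sanity check on the full exponential loss added above (a sketch, not part of the patch): labels in {0, 1} are mapped to targets in {-1, +1}, and the loss exp(-targets * logits) decays toward 0 for confident correct predictions and grows exponentially for confident mistakes:

    import numpy as np

    labels = np.array([1.0, 1.0, 0.0, 0.0])    # {0, 1} labels
    logits = np.array([2.0, -2.0, -2.0, 2.0])  # per-example predictions
    targets = np.where(labels > 0, 1.0, -1.0)  # convert to {-1, +1}
    loss = np.exp(-targets * logits)
    print(loss)  # [0.1353 7.389 0.1353 7.389]: correct, wrong, correct, wrong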
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index c239e6f8f9..707f621184 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -12,6 +12,15 @@ licenses(["notice"]) # Apache 2.0
py_library(
name = "cluster_resolver_pip",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster_resolver_py",
+ ],
+)
+
+py_library(
+ name = "cluster_resolver_py",
srcs = [
"__init__.py",
"python/training/__init__.py",
@@ -19,7 +28,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":cluster_resolver_py",
+ ":base_cluster_resolver_py",
":gce_cluster_resolver_py",
":tpu_cluster_resolver_py",
"//tensorflow/python:util",
@@ -27,7 +36,7 @@ py_library(
)
py_library(
- name = "cluster_resolver_py",
+ name = "base_cluster_resolver_py",
srcs = ["python/training/cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
@@ -40,7 +49,7 @@ py_library(
srcs = ["python/training/gce_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
- ":cluster_resolver_py",
+ ":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
@@ -50,13 +59,13 @@ py_library(
srcs = ["python/training/tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
- ":cluster_resolver_py",
+ ":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
tf_py_test(
- name = "cluster_resolver_py_test",
+ name = "base_cluster_resolver_py_test",
srcs = ["python/training/cluster_resolver_test.py"],
additional_deps = [
":cluster_resolver_py",
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index a5eba5a8c9..75e00f3267 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -14,6 +14,7 @@ tensorflow/examples/tutorials
tensorflow/examples/tutorials/mnist
tensorflow/python
tensorflow/python/client
+tensorflow/python/compat
tensorflow/python/data
tensorflow/python/data/ops
tensorflow/python/data/util
@@ -61,6 +62,8 @@ tensorflow/python/saved_model
tensorflow/python/summary
tensorflow/python/summary/writer
tensorflow/python/tools
+tensorflow/python/tools/api
+tensorflow/python/tools/api/generator
tensorflow/python/training
tensorflow/python/training/checkpointable
tensorflow/python/user_ops
@@ -68,7 +71,6 @@ tensorflow/python/util
tensorflow/python/util/protobuf
tensorflow/tools
tensorflow/tools/api
-tensorflow/tools/api/generator
tensorflow/tools/graph_transforms
tensorflow/contrib
tensorflow/contrib/all_reduce
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 872b016d2b..067c299a71 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -49,48 +49,43 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
-function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR)
- if(NOT ARGN)
- message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files")
- return()
- endif()
-
- set(${SRCS})
- set(${HDRS})
- foreach(FIL ${ARGN})
- set(ABS_FIL ${ROOT_DIR}/${FIL})
- get_filename_component(FIL_WE ${FIL} NAME_WE)
- get_filename_component(FIL_DIR ${ABS_FIL} PATH)
- file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
-
- list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc")
- list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h")
- list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
- list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
-
- # We adust the path of the gRPC code generation accordingly.
- if(WIN32)
- set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/Release/grpc_cpp_plugin.exe)
- else()
- set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/grpc_cpp_plugin)
+if(NOT WIN32)
+ function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR)
+ if(NOT ARGN)
+ message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files")
+ return()
endif()
- add_custom_command(
- OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
- COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
- ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin=protoc-gen-grpc=${GRPC_PROTOC_PLUGIN_PATH} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
- DEPENDS ${ABS_FIL} protobuf grpc
- COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}"
- VERBATIM )
- endforeach()
-
- set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
- set(${SRCS} ${${SRCS}} PARENT_SCOPE)
- set(${HDRS} ${${HDRS}} PARENT_SCOPE)
-endfunction()
+ set(${SRCS})
+ set(${HDRS})
+ foreach(FIL ${ARGN})
+ set(ABS_FIL ${ROOT_DIR}/${FIL})
+ get_filename_component(FIL_WE ${FIL} NAME_WE)
+ get_filename_component(FIL_DIR ${ABS_FIL} PATH)
+ file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
+
+ list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc")
+ list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h")
+ list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
+ list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
+
+ add_custom_command(
+ OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
+ COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
+ ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin protoc-gen-grpc=${GRPC_BUILD}/grpc_cpp_plugin -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
+ DEPENDS ${ABS_FIL} protobuf grpc
+ COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}"
+ VERBATIM )
+ endforeach()
+
+ set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
+ set(${SRCS} ${${SRCS}} PARENT_SCOPE)
+ set(${HDRS} ${${HDRS}} PARENT_SCOPE)
+ endfunction()
+endif()
function(RELATIVE_PROTOBUF_TEXT_GENERATE_CPP SRCS HDRS ROOT_DIR)
if(NOT ARGN)
@@ -180,14 +175,17 @@ RELATIVE_PROTOBUF_TEXT_GENERATE_CPP(PROTO_TEXT_SRCS PROTO_TEXT_HDRS
${tensorflow_source_dir} ${tf_proto_text_srcs}
)
-file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir}
- "${tensorflow_source_dir}/tensorflow/core/debug/*.proto"
- "${tensorflow_source_dir}/tensorflow/core/protobuf/master_service.proto"
-)
-RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS
- ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs}
-)
-add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS})
+if(WIN32)
+ add_library(tf_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
+else()
+ file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir}
+ "${tensorflow_source_dir}/tensorflow/core/debug/*.proto"
+ )
+ RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS
+ ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs}
+ )
+ add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS})
+endif()
########################################################
# tf_core_lib library
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index e3b59001bc..32b185f07b 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -736,8 +736,8 @@ endif()
# Generate API __init__.py files.
########################################################
-# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
-FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
+# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text})
string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text})
string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text})
@@ -781,7 +781,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE}
- "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"--package=tensorflow.python"
@@ -803,7 +803,7 @@ else (tensorflow_ENABLE_MKL_SUPPORT)
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
- "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"--package=tensorflow.python"
@@ -824,8 +824,8 @@ add_dependencies(tf_python_api tf_python_ops)
# Generate API __init__.py files for tf.estimator.
########################################################
-# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
-FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
+# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text})
string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
@@ -849,10 +849,11 @@ add_custom_command(
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
- "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api"
"--package=tensorflow.python.estimator"
"--apiname=estimator"
+ "--output_package=tensorflow.python.estimator.api"
"${estimator_api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index eb9482dc25..b2330c4e34 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -193,6 +193,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# flaky test
"${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py"
"${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/map_dataset_op_test.py"
# Fails because uses data dependencies with bazel
"${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py"
@@ -216,7 +217,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py
${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py
${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py
-
+ # Tests too large to run.
+ ${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
)
if (WIN32)
set(tf_test_src_py_exclude
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index a0dd3881a8..5931c8a279 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -18,7 +18,7 @@ These functions allow for recursive copying of elements (ops and variables)
from one graph to another. The copied elements are initialized inside a
user-specified scope in the other graph. There are separate functions to
copy ops and variables.
-There is also a function to retrive the copied version of an op from the
+There is also a function to retrieve the copied version of an op from the
first graph inside a scope in the second graph.
@@copy_op_to_graph
@@ -77,7 +77,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
else:
collections.append(scope + '/' + name)
- #See if its trainable.
+ #See if it's trainable.
trainable = (
org_instance in org_instance.graph.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES))
@@ -162,7 +162,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
if isinstance(org_instance, ops.Tensor):
- #If its a Tensor, it is one of the outputs of the underlying
+ #If it's a Tensor, it is one of the outputs of the underlying
#op. Therefore, copy the op itself and return the appropriate
#output.
op = org_instance.op
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 156538b4e0..675330716b 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -34,6 +34,7 @@ See @{$guide/datasets$Importing Data} for an overview.
@@batch_and_drop_remainder
@@bucket_by_sequence_length
@@choose_from_datasets
+@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
@@ -86,6 +87,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 079c8bbd8e..9a454efc4c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -229,9 +229,11 @@ cuda_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:function",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/compat:compat",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
+ tags = ["no_windows_gpu"],
)
py_test(
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index a075dfd8b5..b7025f3802 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import hashlib
import itertools
import os
import time
@@ -32,9 +33,12 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
+_NUMPY_RANDOM_SEED = 42
+
class MapDatasetTest(test.TestCase):
@@ -142,80 +146,123 @@ class MapDatasetTest(test.TestCase):
class MapDatasetBenchmark(test.Benchmark):
+ # The purpose of this benchmark is to compare the performance of chaining
+ # vs. fusing the map and batch transformations across various configurations.
+ #
+ # NOTE: It is recommended to build the benchmark with
+ # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ # and execute it on a machine with at least 32 CPU cores.
def benchmarkMapAndBatch(self):
- small = itertools.product([1, 4], [1, 4], [1, 4], [16, 64], [100])
- large = itertools.product([16, 64], [16, 64], [16, 64], [256, 1024], [10])
-
- num_iters = 100
-
- def benchmark(series):
- for num_calls, inter_op, element_size, batch_size, num_steps in series:
- dataset = dataset_ops.Dataset.from_tensors(
- np.random.randint(100, size=element_size)).repeat().map(
- lambda x: x,
- num_parallel_calls=num_calls).batch(batch_size=batch_size)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+ hashlib.sha1(label.encode("utf-8")).hexdigest(),
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+
+ print("%s:" % label)
+ for num_calls, inter_op, element_size, batch_size in series:
+
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+ element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
- fused_dataset = dataset_ops.Dataset.from_tensors(
- np.random.randint(100, size=element_size)).repeat(None).apply(
- batching.map_and_batch(
- lambda x: x,
- num_parallel_calls=num_calls,
- batch_size=batch_size))
- fused_iterator = fused_dataset.make_one_shot_iterator()
- fused_get_next = fused_iterator.get_next()
-
- fused_deltas = []
+ chained_deltas = []
with session.Session(
config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op)) as sess:
-
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
for _ in range(5):
- sess.run(fused_get_next)
+ sess.run(chained_get_next.op)
for _ in range(num_iters):
start = time.time()
- for _ in range(num_steps):
- sess.run(fused_get_next)
+ sess.run(chained_get_next.op)
end = time.time()
- fused_deltas.append(end - start)
+ chained_deltas.append(end - start)
- chained_deltas = []
+ fused_dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
with session.Session(
config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op)) as sess:
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
for _ in range(5):
- sess.run(get_next)
+ sess.run(fused_get_next.op)
for _ in range(num_iters):
start = time.time()
- for _ in range(num_steps):
- sess.run(get_next)
+ sess.run(fused_get_next.op)
end = time.time()
- chained_deltas.append(end - start)
+ fused_deltas.append(end - start)
- chained_wall_time = np.median(chained_deltas) / num_iters
- fused_wall_time = np.median(fused_deltas) / num_iters
print(
"batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
- "element size: %d, chained wall time: %f, fused wall time: %f" %
- (batch_size, num_calls, inter_op, element_size, chained_wall_time,
- fused_wall_time))
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
self.report_benchmark(
iters=num_iters,
- wall_time=chained_wall_time,
- name="chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
- % (batch_size, num_calls, inter_op, element_size))
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
self.report_benchmark(
iters=num_iters,
- wall_time=fused_wall_time,
- name="fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
- % (batch_size, num_calls, inter_op, element_size))
-
- benchmark(small)
- benchmark(large)
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print("")
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
if __name__ == "__main__":
test.main()
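
The benchmark above times the same per-element work expressed two ways; schematically, with placeholder names `fn`, `num_calls`, and `batch_size`, and using the contrib `batching` module this file already imports:

    # Chained: independent map and batch stages with a handoff in between.
    chained = dataset.map(
        fn, num_parallel_calls=num_calls).batch(batch_size=batch_size)

    # Fused: a single map_and_batch stage that writes mapped elements
    # straight into the output batch, avoiding the intermediate stage.
    fused = dataset.apply(
        batching.map_and_batch(
            fn, num_parallel_calls=num_calls, batch_size=batch_size))

The reported chained/fused ratios then directly measure the benefit of the fusion.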
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index e35be8a23f..21eebccd11 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -35,8 +35,6 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -50,8 +48,6 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -65,12 +61,21 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- any([node.op == "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testFunctionLibraryDefinitionModification(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply(
+ optimization.optimize(["_test_only_function_rename"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.NotFoundError,
+ "Function .* is not defined."):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
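
For context on the new test above (a sketch under the same contrib `optimization` API this file imports): `optimization.optimize()` takes a list of Grappler rewrite names and applies them to the serialized dataset graph before execution, which is why a rewrite that renames a function makes the original function unresolvable at run time. The rewrite name below is an assumption based on the map-and-batch tests in this file; the rename rewrite itself is test-only:

    dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(2)
    optimized = dataset.apply(optimization.optimize(["map_and_batch_fusion"]))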
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 40a8e46676..82543b1039 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -86,8 +87,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
return (prefetch_op, reset_op, destroy_op)
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
@@ -126,8 +126,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
"/job:localhost/replica:0/task:0/gpu:0")
def testReinitialization(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -167,8 +166,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -271,8 +269,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -332,8 +329,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element["a"].dtype)
self.assertEqual([], next_element["a"].shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
@@ -366,8 +362,7 @@ class PrefetchToDeviceTest(test.TestCase):
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
@@ -417,8 +412,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
for i in range(5):
@@ -451,5 +445,467 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
+class CopyToDeviceTest(test.TestCase):
+
+ def testCopyToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceInt32(self):
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int32, next_element.dtype)
+ self.assertEqual((4,), next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDevice(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDeviceWithPrefetch(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32AndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStringsAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDevicePingPongCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with compat.forward_compatibility_horizon(2018, 8, 4):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
+ back_to_cpu_dataset = device_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = back_to_cpu_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInitAndPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInitAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 21fc17102e..50212d3b52 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -26,10 +26,14 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
def function_buffering_resource(string_arg,
@@ -345,3 +349,172 @@ def prefetch_to_device(device, buffer_size=None):
return _PrefetchToDeviceDataset(dataset, device, buffer_size)
return _apply_fn
+
+
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.Dataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied.
+ target_device: The name of the device to which elements will be copied.
+ source_device: The device on which `input_dataset` is placed.
+ """
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec.from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = core_gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return core_gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func.
+
+ Returns:
+ An `int64` constant 0, produced once the iterator resource is destroyed.
+ """
+ iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+ # pylint: enable=protected-access
+
+ # The one_shot_iterator implementation needs a 0-arg _make_dataset function
+ # that captures all of the inputs required to create the dataset. Since
+ # some of the inputs to GeneratorDataset are strings, which cannot be
+ # placed on a GPU, this fails for GPU targets. Therefore, one-shot
+ # iterators are disabled when the target device is a GPU.
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.contrib.data.copy_to_device()` on GPU. Please use "
+ "`Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return core_gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
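For orientation, here is a minimal sketch of how `_CopyToDeviceDataset` is reached from user code, assuming the `tf.contrib.data.copy_to_device()` transformation named in the error message above is applied via `Dataset.apply()`; since one-shot iterators are rejected on GPU, an initializable iterator is used:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Stage elements from the CPU-side input pipeline onto the GPU.
dataset = dataset.apply(tf.contrib.data.copy_to_device("/device:GPU:0"))

with tf.device("/device:GPU:0"):
    # make_one_shot_iterator() would raise ValueError here (see above).
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(next_element))
```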
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 83095c7ba1..9373e37f5f 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -540,11 +540,11 @@ class CsvDataset(dataset_ops.Dataset):
The expected output of its iterations is:
```python
- next = dataset.make_one_shot_iterator().get_next()
+ next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
- print(sess.run(nxt))
+ print(sess.run(next_element))
except tf.errors.OutOfRangeError:
break
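For context, a sketch of constructing the `CsvDataset` whose docstring is fixed above; the filename and two-column schema are hypothetical:

```python
import tensorflow as tf

# One entry per CSV column; plain dtypes mark the columns as required.
record_defaults = [tf.float32, tf.string]
dataset = tf.contrib.data.CsvDataset(["data.csv"], record_defaults)

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```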
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index 74b2cd90a1..1126f76f58 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -30,6 +30,7 @@ py_library(
"//tensorflow/contrib/distribute/python:monitor",
"//tensorflow/contrib/distribute/python:one_device_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
],
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 76711baf3a..2e2c3be853 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrat
from tensorflow.contrib.distribute.python.monitor import Monitor
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.step_fn import *
+from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,6 +42,7 @@ _allowed_symbols = [
'StandardInputStep',
'StandardSingleLossStep',
'TowerContext',
+ 'TPUStrategy',
'get_cross_tower_context',
'get_distribution_strategy',
'get_loss_reduction',
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index eba0dd0ea3..40dbfa3dd2 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -587,6 +587,7 @@ cuda_py_test(
],
tags = [
"multi_and_single_gpu",
+ "no_windows_gpu",
"notsan",
],
)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index b597bce035..6a14b833d2 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -491,13 +491,14 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
components_mean = {}
def model_fn(device_id):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope(
- variable_scope.VariableAggregation.SUM):
- v_sum = variable_scope.variable(1.0)
- with tower_context.tower_local_var_scope(
- variable_scope.VariableAggregation.MEAN):
- v_mean = variable_scope.variable(4.0)
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v_mean = variable_scope.variable(
+ 4.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.MEAN)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
updates = [v_sum.assign_add(2.0 + device_id),
@@ -700,10 +701,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with context.graph_mode():
def model_fn():
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope(
- variable_scope.VariableAggregation.SUM):
- v_sum = variable_scope.variable(1.0)
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
return v_sum
@@ -922,5 +923,49 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(4.5, self.evaluate(mirrored_var))
+class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def testAssignMirroredVarInitializer(self):
+ # This test is not eager compatible since in eager mode variables are
+ # initialized upon construction instead of when the initialization op is run.
+ with context.graph_mode():
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
+ self.evaluate(mirrored_var.initializer)
+ self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
+
+ def testAssignTowerLocalVarInitializer(self):
+ # This test is not eager compatible since in eager mode variables are
+ # initialized upon construction instead of when the initialization op is run.
+ with context.graph_mode():
+ def model_fn():
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+ return v_sum
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ tower_local_var = dist.call_for_each_tower(model_fn)
+ self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable))
+ self.assertFalse(self.evaluate(tower_local_var.is_initialized()))
+ self.evaluate(tower_local_var.initializer)
+ self.assertTrue(self.evaluate(tower_local_var.is_initialized()))
+
if __name__ == "__main__":
test.main()
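A minimal sketch of the API change these tests exercise: the synchronization and aggregation modes are now passed directly at variable creation instead of wrapping the creation in `tower_context.tower_local_var_scope(...)`. When run via `dist.call_for_each_tower` inside a `MirroredStrategy` scope, the call below yields a `values.TowerLocalVariable`, as the tests assert:

```python
from tensorflow.python.ops import variable_scope

def model_fn():
  # Replaces the old tower_local_var_scope(VariableAggregation.SUM) wrapper.
  return variable_scope.variable(
      1.0,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM)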
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
index 0f21a42732..cbfe5df61d 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
@@ -46,7 +46,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy):
* **In-graph replication**: the `client` creates a single `tf.Graph` that
specifies tasks for devices on all workers. The `client` then creates a
client session which will talk to the `master` service of a `worker`. Then
- the `master` will parition the graph and distribute the work to all
+ the `master` will partition the graph and distribute the work to all
participating workers.
* **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
physical machine. We will have multiple `worker`s with different `task`
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index b36ac563d2..1b5e00bc79 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -297,6 +297,12 @@ class MirroredVariable(DistributedVariable, Mirrored,
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
+ # tf.keras uses this attribute to keep track of which variables have been
+ # initialized: when tf.keras gets the default session, it initializes all
+ # uninitialized vars. We need to make _keras_initialized a member of
+ # MirroredVariable because without it, `__getattr__` would delegate to a
+ # component variable.
+ self._keras_initialized = False
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -348,6 +354,28 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For mirrored variables, the `is_initialized`
+ # op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+ @property
+ def initializer(self):
+ # Return the grouped op that runs all the initializers of the component
+ # values of the mirrored variable.
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
@property
def aggregation(self):
return self._aggregation
@@ -435,6 +463,12 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
self._aggregation = aggregation
+ # tf.keras uses this attribute to keep track of which variables have been
+ # initialized: when tf.keras gets the default session, it initializes all
+ # uninitialized vars. We need to make _keras_initialized a member of
+ # TowerLocalVariable because without it, `__getattr__` would delegate to a
+ # component variable.
+ self._keras_initialized = False
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -449,6 +483,28 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
_assert_tower_context()
return self.get().assign(*args, **kwargs)
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For tower local variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+ @property
+ def initializer(self):
+ # Return the grouped op that runs all the initializers of the component
+ # values of the tower local variable.
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
@property
def aggregation(self):
return self._aggregation
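The two properties added above have straightforward graph-level semantics. A sketch of what they reduce to for a distributed variable with two component values, using standalone variables and `tf.is_variable_initialized` in place of the component `is_initialized()` method:

```python
import tensorflow as tf

v0 = tf.Variable(1.0)
v1 = tf.Variable(2.0)

# Equivalent of the new `initializer` property: one grouped op.
init = tf.group(v0.initializer, v1.initializer)

# Equivalent of `is_initialized(name="is_init")`: chained logical_ands,
# where only the final op carries the user-supplied name.
is_init = tf.logical_and(tf.is_variable_initialized(v0),
                         tf.is_variable_initialized(v1),
                         name="is_init")

with tf.Session() as sess:
  print(sess.run(is_init))  # False
  sess.run(init)
  print(sess.run(is_init))  # True
```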
diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD
new file mode 100644
index 0000000000..de2a817d17
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD
@@ -0,0 +1,29 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+ name = "densenet",
+ srcs = ["densenet.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ ],
+)
+
+cuda_py_test(
+ name = "densenet_test",
+ srcs = ["densenet_test.py"],
+ additional_deps = [
+ ":densenet",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet.py b/tensorflow/contrib/eager/python/examples/densenet/densenet.py
new file mode 100644
index 0000000000..3a2b2de250
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet.py
@@ -0,0 +1,274 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Densely Connected Convolutional Networks.
+
+ Reference: [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+l2 = tf.keras.regularizers.l2
+
+
+class ConvBlock(tf.keras.Model):
+ """Convolutional Block consisting of (batchnorm->relu->conv).
+
+ Arguments:
+ num_filters: number of filters passed to a convolutional layer.
+ bottleneck: if True, then a 1x1 Conv is performed followed by a 3x3 Conv.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_filters, bottleneck, weight_decay=1e-4,
+ dropout_rate=0):
+ super(ConvBlock, self).__init__()
+ self.bottleneck = bottleneck
+ inter_filter = num_filters * 4
+ # don't forget to set use_bias=False when using batchnorm
+ self.conv2 = tf.keras.layers.Conv2D(num_filters,
+ (3, 3),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.batchnorm1 = tf.keras.layers.BatchNormalization()
+ self.dropout = tf.keras.layers.Dropout(dropout_rate)
+
+ if self.bottleneck:
+ self.conv1 = tf.keras.layers.Conv2D(inter_filter,
+ (1, 1),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.batchnorm2 = tf.keras.layers.BatchNormalization()
+
+ def call(self, x, training=True):
+ output = self.batchnorm1(x, training=training)
+
+ if self.bottleneck:
+ output = self.conv1(tf.nn.relu(output))
+ output = self.batchnorm2(output, training=training)
+
+ output = self.conv2(tf.nn.relu(output))
+ output = self.dropout(output, training=training)
+
+ return output
+
+
+class TransitionBlock(tf.keras.Model):
+ """Transition Block to reduce the number of features.
+
+ Arguments:
+ num_filters: number of filters passed to a convolutional layer.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_filters, weight_decay=1e-4, dropout_rate=0):
+ super(TransitionBlock, self).__init__()
+ self.batchnorm = tf.keras.layers.BatchNormalization()
+ self.conv = tf.keras.layers.Conv2D(num_filters,
+ (1, 1),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.avg_pool = tf.keras.layers.AveragePooling2D()
+
+ def call(self, x, training=True):
+ output = self.batchnorm(x, training=training)
+ output = self.conv(tf.nn.relu(output))
+ output = self.avg_pool(output)
+ return output
+
+
+class DenseBlock(tf.keras.Model):
+ """Dense Block consisting of ConvBlocks where each block's
+ output is concatenated with its input.
+
+ Arguments:
+ num_layers: Number of layers in each block.
+ growth_rate: number of filters to add per conv block.
+ bottleneck: boolean that decides which part of ConvBlock to call.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_layers, growth_rate, bottleneck,
+ weight_decay=1e-4, dropout_rate=0):
+ super(DenseBlock, self).__init__()
+ self.num_layers = num_layers
+
+ self.blocks = []
+ for _ in range(int(self.num_layers)):
+ self.blocks.append(ConvBlock(growth_rate,
+ bottleneck,
+ weight_decay,
+ dropout_rate))
+
+ def call(self, x, training=True):
+ for i in range(int(self.num_layers)):
+ output = self.blocks[i](x, training=training)
+ x = tf.concat([x, output], axis=-1)
+
+ return x
+
+
+class DenseNet(tf.keras.Model):
+ """Creating the Densenet Architecture.
+
+ Arguments:
+ depth_of_model: number of layers in the model.
+ growth_rate: number of filters to add per conv block.
+ num_of_blocks: number of dense blocks.
+ output_classes: number of output classes.
+ num_layers_in_each_block: number of layers in each block.
+ If -1, then we calculate this by (depth-4)/3.
+ If a positive integer, then it is used as the
+ number of layers per block.
+ If list or tuple, then this list is used directly.
+ bottleneck: boolean that decides which part of the conv block to call.
+ compression: factor by which to reduce the number of inputs (filters)
+ to the transition block.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ pool_initial: If True, add a 7x7 conv with stride 2 followed by a 3x3
+ maxpool; else, do a 3x3 conv with stride 1.
+ include_top: If True, a GlobalAveragePooling layer and a Dense layer are
+ included.
+ """
+
+ def __init__(self, depth_of_model, growth_rate, num_of_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5, weight_decay=1e-4,
+ dropout_rate=0, pool_initial=False, include_top=True):
+ super(DenseNet, self).__init__()
+ self.depth_of_model = depth_of_model
+ self.growth_rate = growth_rate
+ self.num_of_blocks = num_of_blocks
+ self.output_classes = output_classes
+ self.num_layers_in_each_block = num_layers_in_each_block
+ self.bottleneck = bottleneck
+ self.compression = compression
+ self.weight_decay = weight_decay
+ self.dropout_rate = dropout_rate
+ self.pool_initial = pool_initial
+ self.include_top = include_top
+
+ # deciding on number of layers in each block
+ if isinstance(self.num_layers_in_each_block, list) or isinstance(
+ self.num_layers_in_each_block, tuple):
+ self.num_layers_in_each_block = list(self.num_layers_in_each_block)
+ else:
+ if self.num_layers_in_each_block == -1:
+ if self.num_of_blocks != 3:
+ raise ValueError(
+ "Number of blocks must be 3 if num_layers_in_each_block is -1")
+ if (self.depth_of_model - 4) % 3 == 0:
+ num_layers = (self.depth_of_model - 4) / 3
+ if self.bottleneck:
+ num_layers //= 2
+ self.num_layers_in_each_block = [num_layers] * self.num_of_blocks
+ else:
+ raise ValueError("Depth must be 3N+4 if num_layer_in_each_block=-1")
+ else:
+ self.num_layers_in_each_block = [
+ self.num_layers_in_each_block] * self.num_of_blocks
+
+ # setting the filters and stride of the initial conv layer.
+ if self.pool_initial:
+ init_filters = (7, 7)
+ stride = (2, 2)
+ else:
+ init_filters = (3, 3)
+ stride = (1, 1)
+
+ self.num_filters = 2 * self.growth_rate
+
+ # first conv and pool layer
+ self.conv1 = tf.keras.layers.Conv2D(self.num_filters,
+ init_filters,
+ strides=stride,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(
+ self.weight_decay))
+ if self.pool_initial:
+ self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3),
+ strides=(2, 2),
+ padding="same")
+ self.batchnorm1 = tf.keras.layers.BatchNormalization()
+
+ self.batchnorm2 = tf.keras.layers.BatchNormalization()
+
+ # last pooling and fc layer
+ if self.include_top:
+ self.last_pool = tf.keras.layers.GlobalAveragePooling2D()
+ self.classifier = tf.keras.layers.Dense(self.output_classes)
+
+ # calculating the number of filters after each block
+ num_filters_after_each_block = [self.num_filters]
+ for i in range(1, self.num_of_blocks):
+ temp_num_filters = num_filters_after_each_block[i-1] + (
+ self.growth_rate * self.num_layers_in_each_block[i-1])
+ # using compression to reduce the number of inputs to the
+ # transition block
+ temp_num_filters = int(temp_num_filters * compression)
+ num_filters_after_each_block.append(temp_num_filters)
+
+ # dense block initialization
+ self.dense_blocks = []
+ self.transition_blocks = []
+ for i in range(self.num_of_blocks):
+ self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i],
+ self.growth_rate,
+ self.bottleneck,
+ self.weight_decay,
+ self.dropout_rate))
+ if i+1 < self.num_of_blocks:
+ self.transition_blocks.append(
+ TransitionBlock(num_filters_after_each_block[i+1],
+ self.weight_decay,
+ self.dropout_rate))
+
+ def call(self, x, training=True):
+ output = self.conv1(x)
+
+ if self.pool_initial:
+ output = self.batchnorm1(output, training=training)
+ output = tf.nn.relu(output)
+ output = self.pool1(output)
+
+ for i in range(self.num_of_blocks - 1):
+ output = self.dense_blocks[i](output, training=training)
+ output = self.transition_blocks[i](output, training=training)
+
+ output = self.dense_blocks[
+ self.num_of_blocks - 1](output, training=training)
+ output = self.batchnorm2(output, training=training)
+ output = tf.nn.relu(output)
+
+ if self.include_top:
+ output = self.last_pool(output)
+ output = self.classifier(output)
+
+ return output
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
new file mode 100644
index 0000000000..56d3362f3b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -0,0 +1,83 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for various Densenet architectures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.densenet import densenet
+
+
+class DensenetTest(tf.test.TestCase):
+
+ def test_bottleneck_true(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 3
+ output_classes = 10
+ num_layers_in_each_block = -1
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=False, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+ def test_bottleneck_false(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 3
+ output_classes = 10
+ num_layers_in_each_block = -1
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=False, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=False, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+ def test_pool_initial_true(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 4
+ output_classes = 10
+ num_layers_in_each_block = [1, 2, 2, 1]
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=True, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
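The checked-in files only define the model and shape-test it; a hypothetical eager-mode training step for this DenseNet, with random stand-in data, might look like:

```python
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.densenet import densenet

tf.enable_eager_execution()

# Same configuration as test_bottleneck_true above.
model = densenet.DenseNet(depth_of_model=7, growth_rate=2, num_of_blocks=3,
                          output_classes=10, num_layers_in_each_block=-1)
optimizer = tf.train.AdamOptimizer()

# Random stand-in batch (hypothetical; no real dataset is wired up here).
images = tf.random_uniform((8, 32, 32, 3))
labels = tf.random_uniform((8,), maxval=10, dtype=tf.int32)

with tf.GradientTape() as tape:
  logits = model(images, training=True)
  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits))

grads = tape.gradient(loss, model.variables)
optimizer.apply_gradients(zip(grads, model.variables))
```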
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
index cc9cf53410..b33243021b 100644
--- a/tensorflow/contrib/eager/python/examples/gan/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -214,7 +214,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
total_generator_loss = 0.0
total_discriminator_loss = 0.0
- for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
+ for (batch_index, images) in enumerate(dataset):
with tf.device('/cpu:0'):
tf.assign_add(step_counter, 1)
@@ -227,7 +227,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
maxval=1.,
seed=batch_index)
- with tf.GradientTape(persistent=True) as g:
+ # We can use two tapes or a single persistent tape.
+ # Using two tapes is more memory efficient, since intermediate tensors
+ # can be released between the two .gradient() calls below.
+ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise)
tf.contrib.summary.image(
'generated_images',
@@ -243,9 +246,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer,
generator_loss_val = generator_loss(discriminator_gen_outputs)
total_generator_loss += generator_loss_val
- generator_grad = g.gradient(generator_loss_val, generator.variables)
- discriminator_grad = g.gradient(discriminator_loss_val,
- discriminator.variables)
+ generator_grad = gen_tape.gradient(generator_loss_val,
+ generator.variables)
+ discriminator_grad = disc_tape.gradient(discriminator_loss_val,
+ discriminator.variables)
generator_optimizer.apply_gradients(
zip(generator_grad, generator.variables))
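A standalone sketch of the two-tape pattern the comment above describes, with toy variables rather than the GAN: each non-persistent tape supports exactly one `.gradient()` call and can release its saved intermediates as soon as that call is made.

```python
import tensorflow as tf
tf.enable_eager_execution()

w1 = tf.Variable(3.0)
w2 = tf.Variable(2.0)

with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
  y1 = w1 * w1          # recorded by both tapes
  y2 = w2 * y1

g1 = tape1.gradient(y1, w1)  # dy1/dw1 = 2 * w1 = 6.0
g2 = tape2.gradient(y2, w2)  # dy2/dw2 = y1 = 9.0
print(g1.numpy(), g2.numpy())
```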
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
new file mode 100644
index 0000000000..1a5a186e7a
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -0,0 +1,1184 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "image_captioning_with_attention.ipynb",
+ "version": "0.3.2",
+ "views": {},
+ "default_view": {},
+ "provenance": [
+ {
+ "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg",
+ "timestamp": 1530222436922
+ }
+ ],
+ "private_outputs": true,
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "metadata": {
+ "id": "K2s1A9eLRPEj",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Cffg2i257iMS",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Image Captioning with Attention\n",
+ "\n",
+ "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
+ "</td><td>\n",
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "QASbY_HGo4Lq",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Image captioning is the task of generating a caption for an image. Given an image like this:\n",
+ "\n",
+ "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n",
+ "\n",
+ "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
+ "\n",
+ "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
+ "\n",
+ "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
+ "\n",
+ "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n",
+ "\n",
+ "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n",
+ "\n",
+ "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n",
+ "\n",
+ "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n",
+ "\n",
+ "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "U8l4RJ0XRPEm",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Import TensorFlow and enable eager execution\n",
+ "# This code requires TensorFlow version >=1.9\n",
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "# We'll generate plots of attention in order to see which parts of an image\n",
+ "# our model focuses on during captioning\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Scikit-learn includes many helpful utilities\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.utils import shuffle\n",
+ "\n",
+ "import re\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import time\n",
+ "import json\n",
+ "from glob import glob\n",
+ "from PIL import Image\n",
+ "import pickle"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "b6qbGw8MRPE5",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Download and prepare the MS-COCO dataset\n",
+ "\n",
+ "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n",
+ "\n",
+ "**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "krQuPYTtRPE7",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "annotation_zip = tf.keras.utils.get_file('captions.zip', \n",
+ " cache_subdir=os.path.abspath('.'),\n",
+ " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n",
+ " extract = True)\n",
+ "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n",
+ "\n",
+ "name_of_zip = 'train2014.zip'\n",
+ "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n",
+ " image_zip = tf.keras.utils.get_file(name_of_zip, \n",
+ " cache_subdir=os.path.abspath('.'),\n",
+ " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n",
+ " extract = True)\n",
+ " PATH = os.path.dirname(image_zip)+'/train2014/'\n",
+ "else:\n",
+ " PATH = os.path.abspath('.')+'/train2014/'"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "aANEzb5WwSzg",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Optionally, limit the size of the training set for faster training\n",
+ "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4G3b8x8_RPFD",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# read the json file\n",
+ "with open(annotation_file, 'r') as f:\n",
+ " annotations = json.load(f)\n",
+ "\n",
+ "# storing the captions and the image name in vectors\n",
+ "all_captions = []\n",
+ "all_img_name_vector = []\n",
+ "\n",
+ "for annot in annotations['annotations']:\n",
+ " caption = '<start> ' + annot['caption'] + ' <end>'\n",
+ " image_id = annot['image_id']\n",
+ " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n",
+ " \n",
+ " all_img_name_vector.append(full_coco_image_path)\n",
+ " all_captions.append(caption)\n",
+ "\n",
+ "# shuffling the captions and image_names together\n",
+ "# setting a random state\n",
+ "train_captions, img_name_vector = shuffle(all_captions,\n",
+ " all_img_name_vector,\n",
+ " random_state=1)\n",
+ "\n",
+ "# selecting the first 30000 captions from the shuffled set\n",
+ "num_examples = 30000\n",
+ "train_captions = train_captions[:num_examples]\n",
+ "img_name_vector = img_name_vector[:num_examples]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "mPBMgK34RPFL",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "len(train_captions), len(all_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "8cSW4u-ORPFQ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess the images using InceptionV3\n",
+ "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n",
+ "\n",
+ "First, we will need to convert the images into the format inceptionV3 expects by:\n",
+ "* Resizing the image to (299, 299)\n",
+ "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "zXR0217aRPFR",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def load_image(image_path):\n",
+ " img = tf.read_file(image_path)\n",
+ " img = tf.image.decode_jpeg(img, channels=3)\n",
+ " img = tf.image.resize_images(img, (299, 299))\n",
+ " img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
+ " return img, image_path"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "MDvIu4sXRPFV",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Initialize InceptionV3 and load the pretrained Imagenet weights\n",
+ "\n",
+ "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n",
+ "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n",
+ "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n",
+ "* We avoid doing this during training so it does not become a bottleneck. \n",
+ "* After all the images are passed through the network, we pickle the dictionary and save it to disk."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "RD3vW4SsRPFW",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "image_model = tf.keras.applications.InceptionV3(include_top=False, \n",
+ " weights='imagenet')\n",
+ "new_input = image_model.input\n",
+ "hidden_layer = image_model.layers[-1].output\n",
+ "\n",
+ "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "rERqlR3WRPGO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Caching the features extracted from InceptionV3\n",
+ "\n",
+ "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n",
+ "\n",
+ "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n",
+ "\n",
+ "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n",
+ "\n",
+ "```for img, path in image_dataset:``` \n",
+ "\n",
+ "to:\n",
+ "\n",
+ "```for img, path in tqdm(image_dataset):```."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Dx_fvbVgRPGQ",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# getting the unique images\n",
+ "encode_train = sorted(set(img_name_vector))\n",
+ "\n",
+ "# feel free to change the batch_size according to your system configuration\n",
+ "image_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " encode_train).map(load_image).batch(16)\n",
+ "\n",
+ "for img, path in image_dataset:\n",
+ " batch_features = image_features_extract_model(img)\n",
+ " batch_features = tf.reshape(batch_features, \n",
+ " (batch_features.shape[0], -1, batch_features.shape[3]))\n",
+ "\n",
+ " for bf, p in zip(batch_features, path):\n",
+ " path_of_feature = p.numpy().decode(\"utf-8\")\n",
+ " np.save(path_of_feature, bf.numpy())"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nyqH3zFwRPFi",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess and tokenize the captions\n",
+ "\n",
+ "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n",
+ "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n",
+ "* Finally, we create a word --> index mapping and vice-versa.\n",
+ "* We will then pad all sequences to the be same length as the longest one. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "HZfK8RhQRPFj",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# This will find the maximum length of any caption in our dataset\n",
+ "def calc_max_length(tensor):\n",
+ " return max(len(t) for t in tensor)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "oJGE34aiRPFo",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# The steps above is a general process of dealing with text processing\n",
+ "\n",
+ "# choosing the top 5000 words from the vocabulary\n",
+ "top_k = 5000\n",
+ "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n",
+ " oov_token=\"<unk>\", \n",
+ " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n",
+ "tokenizer.fit_on_texts(train_captions)\n",
+ "train_seqs = tokenizer.texts_to_sequences(train_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "8Q44tNQVRPFt",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n",
+ "# putting <unk> token in the word2idx dictionary\n",
+ "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n",
+ "tokenizer.word_index['<pad>'] = 0"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "0fpJb5ojRPFv",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# creating the tokenized vectors\n",
+ "train_seqs = tokenizer.texts_to_sequences(train_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "olQArbgbRPF1",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# creating a reverse mapping (index -> word)\n",
+ "index_word = {value:key for key, value in tokenizer.word_index.items()}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AidglIZVRPF4",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# padding each vector to the max_length of the captions\n",
+ "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n",
+ "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "gL0wkttkRPGA",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# calculating the max_length \n",
+ "# used to store the attention weights\n",
+ "max_length = calc_max_length(train_seqs)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "M3CD75nDpvTI",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Split the data into training and testing"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "iS7DDMszRPGF",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Create training and validation sets using 80-20 split\n",
+ "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n",
+ " cap_vector, \n",
+ " test_size=0.2, \n",
+ " random_state=0)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "XmViPkRFRPGH",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "uEWM9xrYcg45",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Q3TnZ1ToRPGV",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# feel free to change these parameters according to your system's configuration\n",
+ "\n",
+ "BATCH_SIZE = 64\n",
+ "BUFFER_SIZE = 1000\n",
+ "embedding_dim = 256\n",
+ "units = 512\n",
+ "vocab_size = len(tokenizer.word_index)\n",
+ "# shape of the vector extracted from InceptionV3 is (64, 2048)\n",
+ "# these two variables represent that\n",
+ "features_shape = 2048\n",
+ "attention_features_shape = 64"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "SmZS2N0bXG3T",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# loading the numpy files \n",
+ "def map_func(img_name, cap):\n",
+ " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n",
+ " return img_tensor, cap"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "FDF_Nm3tRPGZ",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n",
+ "\n",
+ "# using map to load the numpy files in parallel\n",
+ "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n",
+ "# https://www.tensorflow.org/api_docs/python/tf/py_func\n",
+ "dataset = dataset.map(lambda item1, item2: tf.py_func(\n",
+ " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n",
+ "\n",
+ "# shuffling and batching\n",
+ "dataset = dataset.shuffle(BUFFER_SIZE)\n",
+ "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n",
+ "dataset = dataset.batch(BATCH_SIZE)\n",
+ "dataset = dataset.prefetch(1)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nrvoDphgRPGd",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Model\n",
+ "\n",
+ "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
+ "\n",
+ "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n",
+ "\n",
+ "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n",
+ "* We squash that to a shape of (64, 2048).\n",
+ "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n",
+ "* The RNN(here GRU) attends over the image to predict the next word."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "AAppCGLKRPGd",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def gru(units):\n",
+ " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n",
+ " # significant speedup).\n",
+ " if tf.test.is_gpu_available():\n",
+ " return tf.keras.layers.CuDNNGRU(units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ " else:\n",
+ " return tf.keras.layers.GRU(units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_activation='sigmoid', \n",
+ " recurrent_initializer='glorot_uniform')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "ja2LFTMSdeV3",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class BahdanauAttention(tf.keras.Model):\n",
+ " def __init__(self, units):\n",
+ " super(BahdanauAttention, self).__init__()\n",
+ " self.W1 = tf.keras.layers.Dense(units)\n",
+ " self.W2 = tf.keras.layers.Dense(units)\n",
+ " self.V = tf.keras.layers.Dense(1)\n",
+ " \n",
+ " def call(self, features, hidden):\n",
+ " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n",
+ " \n",
+ " # hidden shape == (batch_size, hidden_size)\n",
+ " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n",
+ " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
+ " \n",
+ " # score shape == (batch_size, 64, hidden_size)\n",
+ " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n",
+ " \n",
+ " # attention_weights shape == (batch_size, 64, 1)\n",
+ " # we get 1 at the last axis because we are applying score to self.V\n",
+ " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
+ " \n",
+ " # context_vector shape after sum == (batch_size, hidden_size)\n",
+ " context_vector = attention_weights * features\n",
+ " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
+ " \n",
+ " return context_vector, attention_weights"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AZ7R1RxHRPGf",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class CNN_Encoder(tf.keras.Model):\n",
+ " # Since we have already extracted the features and dumped it using pickle\n",
+ " # This encoder passes those features through a Fully connected layer\n",
+ " def __init__(self, embedding_dim):\n",
+ " super(CNN_Encoder, self).__init__()\n",
+ " # shape after fc == (batch_size, 64, embedding_dim)\n",
+ " self.fc = tf.keras.layers.Dense(embedding_dim)\n",
+ " \n",
+ " def call(self, x):\n",
+ " x = self.fc(x)\n",
+ " x = tf.nn.relu(x)\n",
+ " return x"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "V9UbGQmERPGi",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class RNN_Decoder(tf.keras.Model):\n",
+ " def __init__(self, embedding_dim, units, vocab_size):\n",
+ " super(RNN_Decoder, self).__init__()\n",
+ " self.units = units\n",
+ "\n",
+ " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
+ " self.gru = gru(self.units)\n",
+ " self.fc1 = tf.keras.layers.Dense(self.units)\n",
+ " self.fc2 = tf.keras.layers.Dense(vocab_size)\n",
+ " \n",
+ " self.attention = BahdanauAttention(self.units)\n",
+ " \n",
+ " def call(self, x, features, hidden):\n",
+ " # defining attention as a separate model\n",
+ " context_vector, attention_weights = self.attention(features, hidden)\n",
+ " \n",
+ " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
+ " x = self.embedding(x)\n",
+ " \n",
+ " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
+ " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
+ " \n",
+ " # passing the concatenated vector to the GRU\n",
+ " output, state = self.gru(x)\n",
+ " \n",
+ " # shape == (batch_size, max_length, hidden_size)\n",
+ " x = self.fc1(output)\n",
+ " \n",
+ " # x shape == (batch_size * max_length, hidden_size)\n",
+ " x = tf.reshape(x, (-1, x.shape[2]))\n",
+ " \n",
+ " # output shape == (batch_size * max_length, vocab)\n",
+ " x = self.fc2(x)\n",
+ "\n",
+ " return x, state, attention_weights\n",
+ "\n",
+ " def reset_state(self, batch_size):\n",
+ " return tf.zeros((batch_size, self.units))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Qs_Sr03wRPGk",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "encoder = CNN_Encoder(embedding_dim)\n",
+ "decoder = RNN_Decoder(embedding_dim, units, vocab_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "-bYN7xA0RPGl",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "optimizer = tf.train.AdamOptimizer()\n",
+ "\n",
+ "# We are masking the loss calculated for padding\n",
+ "def loss_function(real, pred):\n",
+ " mask = 1 - np.equal(real, 0)\n",
+ " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
+ " return tf.reduce_mean(loss_)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
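+ {
+ "metadata": {
+ "id": "mask-sanity-check",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "As a minimal sanity check of the masking above (a sketch with made-up token ids): positions holding the `<pad>` index 0 get a mask value of 0, so they contribute nothing to the loss."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "mask-sanity-check-code",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# made-up batch of token ids; 0 is the <pad> index\n",
+ "real = np.array([[3, 7, 0, 0]])\n",
+ "print(1 - np.equal(real, 0))  # [[1 1 0 0]]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },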
+ {
+ "metadata": {
+ "id": "PHod7t72RPGn",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Training\n",
+ "\n",
+ "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n",
+ "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n",
+ "* The decoder returns the predictions and the decoder hidden state.\n",
+ "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
+ "* Use teacher forcing to decide the next input to the decoder.\n",
+ "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n",
+ "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Vt4WZ5mhJE-E",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# adding this in a separate cell because if you run the training cell \n",
+ "# many times, the loss_plot array will be reset\n",
+ "loss_plot = []"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "UlA4VIQpRPGo",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "EPOCHS = 20\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " start = time.time()\n",
+ " total_loss = 0\n",
+ " \n",
+ " for (batch, (img_tensor, target)) in enumerate(dataset):\n",
+ " loss = 0\n",
+ " \n",
+ " # initializing the hidden state for each batch\n",
+ " # because the captions are not related from image to image\n",
+ " hidden = decoder.reset_state(batch_size=target.shape[0])\n",
+ "\n",
+ " dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * BATCH_SIZE, 1)\n",
+ " \n",
+ " with tf.GradientTape() as tape:\n",
+ " features = encoder(img_tensor)\n",
+ " \n",
+ " for i in range(1, target.shape[1]):\n",
+ " # passing the features through the decoder\n",
+ " predictions, hidden, _ = decoder(dec_input, features, hidden)\n",
+ "\n",
+ " loss += loss_function(target[:, i], predictions)\n",
+ " \n",
+ " # using teacher forcing\n",
+ " dec_input = tf.expand_dims(target[:, i], 1)\n",
+ " \n",
+ " total_loss += (loss / int(target.shape[1]))\n",
+ " \n",
+ " variables = encoder.variables + decoder.variables\n",
+ " \n",
+ " gradients = tape.gradient(loss, variables) \n",
+ " \n",
+ " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
+ " \n",
+ " if batch % 100 == 0:\n",
+ " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n",
+ " batch, \n",
+ " loss.numpy() / int(target.shape[1])))\n",
+ " # storing the epoch end loss value to plot later\n",
+ " loss_plot.append(total_loss / len(cap_vector))\n",
+ " \n",
+ " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n",
+ " total_loss/len(cap_vector)))\n",
+ " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1Wm83G-ZBPcC",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "plt.plot(loss_plot)\n",
+ "plt.xlabel('Epochs')\n",
+ "plt.ylabel('Loss')\n",
+ "plt.title('Loss Plot')\n",
+ "plt.show()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "xGvOcLQKghXN",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Caption!\n",
+ "\n",
+ "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
+ "* Stop predicting when the model predicts the end token.\n",
+ "* And store the attention weights for every time step."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "RCWpDtyNRPGs",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def evaluate(image):\n",
+ " attention_plot = np.zeros((max_length, attention_features_shape))\n",
+ "\n",
+ " hidden = decoder.reset_state(batch_size=1)\n",
+ "\n",
+ " temp_input = tf.expand_dims(load_image(image)[0], 0)\n",
+ " img_tensor_val = image_features_extract_model(temp_input)\n",
+ " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n",
+ "\n",
+ " features = encoder(img_tensor_val)\n",
+ "\n",
+ " dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n",
+ " result = []\n",
+ "\n",
+ " for i in range(max_length):\n",
+ " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n",
+ "\n",
+ " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
+ "\n",
+ " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " result.append(index_word[predicted_id])\n",
+ "\n",
+ " if index_word[predicted_id] == '<end>':\n",
+ " return result, attention_plot\n",
+ "\n",
+ " dec_input = tf.expand_dims([predicted_id], 0)\n",
+ "\n",
+ " attention_plot = attention_plot[:len(result), :]\n",
+ " return result, attention_plot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "fD_y7PD6RPGt",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def plot_attention(image, result, attention_plot):\n",
+ " temp_image = np.array(Image.open(image))\n",
+ "\n",
+ " fig = plt.figure(figsize=(10, 10))\n",
+ " \n",
+ " len_result = len(result)\n",
+ " for l in range(len_result):\n",
+ " temp_att = np.resize(attention_plot[l], (8, 8))\n",
+ " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n",
+ " ax.set_title(result[l])\n",
+ " img = ax.imshow(temp_image)\n",
+ " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "io7ws3ReRPGv",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# captions on the validation set\n",
+ "rid = np.random.randint(0, len(img_name_val))\n",
+ "image = img_name_val[rid]\n",
+ "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n",
+ "result, attention_plot = evaluate(image)\n",
+ "\n",
+ "print ('Real Caption:', real_caption)\n",
+ "print ('Prediction Caption:', ' '.join(result))\n",
+ "plot_attention(image, result, attention_plot)\n",
+ "# opening the image\n",
+ "Image.open(img_name_val[rid])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Rprk3HEvZuxb",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Try it on your own images\n",
+ "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "9Psd1quzaAWg",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "image_url = 'https://tensorflow.org/images/surf.jpg'\n",
+ "image_extension = image_url[-4:]\n",
+ "image_path = tf.keras.utils.get_file('image'+image_extension, \n",
+ " origin=image_url)\n",
+ "\n",
+ "result, attention_plot = evaluate(image_path)\n",
+ "print ('Prediction Caption:', ' '.join(result))\n",
+ "plot_attention(image_path, result, attention_plot)\n",
+ "# opening the image\n",
+ "Image.open(image_path)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "VJZXyJco6uLO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Next steps\n",
+ "\n",
+ "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset."
+ ]
+ }
+ ]
+}
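A note on the pattern in the training loop above: it relies on teacher forcing, meaning at step `i` the decoder is fed the ground-truth token `target[:, i]` rather than its own previous prediction, which keeps early training stable. A minimal sketch of that pattern in plain Python/NumPy (`toy_decoder` is a hypothetical stand-in for the notebook's decoder, not its real API):

```python
import numpy as np

def toy_decoder(token, hidden, vocab_size=5):
    # Hypothetical stand-in for the notebook's decoder: returns fake
    # logits over the vocabulary plus an updated hidden "state".
    rng = np.random.RandomState(token + hidden)
    return rng.randn(vocab_size), hidden + 1

def teacher_forced_loss(target_tokens):
    """Accumulates a toy per-step loss, feeding the label at each step."""
    hidden, loss = 0, 0.0
    dec_input = target_tokens[0]           # the <start> token
    for i in range(1, len(target_tokens)):
        logits, hidden = toy_decoder(dec_input, hidden)
        loss += -logits[target_tokens[i]]  # toy stand-in for cross-entropy
        dec_input = target_tokens[i]       # teacher forcing: feed the label
    return loss

print(teacher_forced_loss([0, 3, 1, 4, 2]))
```

At inference time (the `evaluate` function above) the same loop runs free: `dec_input` is set to the sampled prediction instead of the label, which is why no teacher forcing appears there.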
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
new file mode 100644
index 0000000000..6be09f98df
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -0,0 +1,689 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "text_generation.ipynb",
+ "version": "0.3.2",
+ "views": {},
+ "default_view": {},
+ "provenance": [],
+ "private_outputs": true,
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "metadata": {
+ "id": "hcD2nPQvPOFM",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
+ "\n",
+ "# Text Generation using a RNN\n",
+ "\n",
+ "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
+ "</td><td>\n",
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on Github</a></td></table>"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BwpJ5IffzRG6",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
+ "\n",
+ "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n",
+ " \n",
+ "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n",
+ "\n",
+ "```\n",
+ "were to the death of him\n",
+ "And nothing of the field in the view of hell,\n",
+ "When I said, banish him, I will not burn thee that would live.\n",
+ "\n",
+ "HENRY BOLINGBROKE:\n",
+ "My gracious uncle--\n",
+ "\n",
+ "DUKE OF YORK:\n",
+ "As much disgraced to the court, the gods them speak,\n",
+ "And now in peace himself excuse thee in the world.\n",
+ "\n",
+ "HORTENSIO:\n",
+ "Madam, 'tis not the cause of the counterfeit of the earth,\n",
+ "And leave me to the sun that set them on the earth\n",
+ "And leave the world and are revenged for thee.\n",
+ "\n",
+ "GLOUCESTER:\n",
+ "I would they were talking with the very name of means\n",
+ "To make a puppet of a guest, and therefore, good Grumio,\n",
+ "Nor arm'd to prison, o' the clouds, of the whole field,\n",
+ "With the admire\n",
+ "With the feeding of thy chair, and we have heard it so,\n",
+ "I thank you, sir, he is a visor friendship with your silly your bed.\n",
+ "\n",
+ "SAMPSON:\n",
+ "I do desire to live, I pray: some stand of the minds, make thee remedies\n",
+ "With the enemies of my soul.\n",
+ "\n",
+ "MENENIUS:\n",
+ "I'll keep the cause of my mistress.\n",
+ "\n",
+ "POLIXENES:\n",
+ "My brother Marcius!\n",
+ "\n",
+ "Second Servant:\n",
+ "Will't ple\n",
+ "```\n",
+ "\n",
+ "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n",
+ "\n",
+ "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n",
+ "\n",
+ "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n",
+ "\n",
+ "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "R3p22DBDsaCA",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Install unidecode library\n",
+ "A helpful library to convert unicode to ASCII."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "wZ6LOM12wKGH",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "!pip install unidecode"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "WGyKZj3bzf9p",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Import tensorflow and enable eager execution."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "yG_n40gFzf9s",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Import TensorFlow >= 1.9 and enable eager execution\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "# Note: Once you enable eager execution, it cannot be disabled. \n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "import numpy as np\n",
+ "import re\n",
+ "import random\n",
+ "import unidecode\n",
+ "import time"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "EHDoRoc5PKWz",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Download the dataset\n",
+ "\n",
+ "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "pD_55cOxLkAb",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/yashkatariya/shakespeare.txt')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "UHjdCjDuSvX_",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Read the dataset\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "-E5JvY3wzf94",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "text = unidecode.unidecode(open(path_to_file).read())\n",
+ "# length of text is the number of characters in it\n",
+ "print (len(text))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Il9ww98izf-D",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "IalZLbvOzf-F",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# unique contains all the unique characters in the file\n",
+ "unique = sorted(set(text))\n",
+ "\n",
+ "# creating a mapping from unique characters to indices\n",
+ "char2idx = {u:i for i, u in enumerate(unique)}\n",
+ "idx2char = {i:u for i, u in enumerate(unique)}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1v_qUYfAzf-I",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# setting the maximum length sentence we want for a single input in characters\n",
+ "max_length = 100\n",
+ "\n",
+ "# length of the vocabulary in chars\n",
+ "vocab_size = len(unique)\n",
+ "\n",
+ "# the embedding dimension \n",
+ "embedding_dim = 256\n",
+ "\n",
+ "# number of RNN (here GRU) units\n",
+ "units = 1024\n",
+ "\n",
+ "# batch size \n",
+ "BATCH_SIZE = 64\n",
+ "\n",
+ "# buffer size to shuffle our dataset\n",
+ "BUFFER_SIZE = 10000"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "LFjSVAlWzf-N",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating the input and output tensors\n",
+ "\n",
+ "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n",
+ "\n",
+ "But first, we need to create the input and output vectors.\n",
+ "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n",
+ "\n",
+ "For example, consider that the string = 'tensorflow' and the max_length is 9\n",
+ "\n",
+ "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n",
+ "\n",
+ "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "0UHJDA39zf-O",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "input_text = []\n",
+ "target_text = []\n",
+ "\n",
+ "for f in range(0, len(text)-max_length, max_length):\n",
+ " inps = text[f:f+max_length]\n",
+ " targ = text[f+1:f+1+max_length]\n",
+ "\n",
+ " input_text.append([char2idx[i] for i in inps])\n",
+ " target_text.append([char2idx[t] for t in targ])\n",
+ " \n",
+ "print (np.array(input_text).shape)\n",
+ "print (np.array(target_text).shape)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "MJdfPmdqzf-R",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating batches and shuffling them using tf.data"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "p2pGotuNzf-S",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n",
+ "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "m8gPwEjRzf-Z",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating the model\n",
+ "\n",
+ "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. We use 3 layers to define our model.\n",
+ "\n",
+ "* Embedding layer\n",
+ "* GRU layer (you can use an LSTM layer here)\n",
+ "* Fully connected layer"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "P3KTiiInzf-a",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class Model(tf.keras.Model):\n",
+ " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n",
+ " super(Model, self).__init__()\n",
+ " self.units = units\n",
+ " self.batch_sz = batch_size\n",
+ "\n",
+ " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
+ "\n",
+ " if tf.test.is_gpu_available():\n",
+ " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ " else:\n",
+ " self.gru = tf.keras.layers.GRU(self.units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_activation='sigmoid', \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ "\n",
+ " self.fc = tf.keras.layers.Dense(vocab_size)\n",
+ " \n",
+ " def call(self, x, hidden):\n",
+ " x = self.embedding(x)\n",
+ "\n",
+ " # output shape == (batch_size, max_length, hidden_size) \n",
+ " # states shape == (batch_size, hidden_size)\n",
+ "\n",
+ " # states variable to preserve the state of the model\n",
+ " # this will be used to pass at every step to the model while training\n",
+ " output, states = self.gru(x, initial_state=hidden)\n",
+ "\n",
+ "\n",
+ " # reshaping the output so that we can pass it to the Dense layer\n",
+ " # after reshaping the shape is (batch_size * max_length, hidden_size)\n",
+ " output = tf.reshape(output, (-1, output.shape[2]))\n",
+ "\n",
+ " # The dense layer will output predictions for every time_steps(max_length)\n",
+ " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n",
+ " x = self.fc(output)\n",
+ "\n",
+ " return x, states"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "trpqTWyvk0nr",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Call the model and set the optimizer and the loss function"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "7t2XrzEOzf-e",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "dkjWIATszf-h",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "optimizer = tf.train.AdamOptimizer()\n",
+ "\n",
+ "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n",
+ "def loss_function(real, preds):\n",
+ " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "lPrP0XMUzf-p",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Train the model\n",
+ "\n",
+ "Here we will use a custom training loop with the help of GradientTape()\n",
+ "\n",
+ "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). We do this by calling the function defined while creating the model.\n",
+ "\n",
+ "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n",
+ "\n",
+ "* There are a lot of interesting things happening here.\n",
+ " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n",
+ " * The model then returns the predictions **P1** and **H1**.\n",
+ " * For the next batch of input, the model receives **I1** and **H1**.\n",
+ " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n",
+ " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n",
+ "\n",
+ "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n",
+ "\n",
+ "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n",
+ "\n",
+ "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "d4tSNwymzf-q",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Training step\n",
+ "\n",
+ "EPOCHS = 30\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " start = time.time()\n",
+ " \n",
+ " # initializing the hidden state at the start of every epoch\n",
+ " hidden = model.reset_states()\n",
+ " \n",
+ " for (batch, (inp, target)) in enumerate(dataset):\n",
+ " with tf.GradientTape() as tape:\n",
+ " # feeding the hidden state back into the model\n",
+ " # This is the interesting step\n",
+ " predictions, hidden = model(inp, hidden)\n",
+ " \n",
+ " # reshaping the target because that's how the \n",
+ " # loss function expects it\n",
+ " target = tf.reshape(target, (-1,))\n",
+ " loss = loss_function(target, predictions)\n",
+ " \n",
+ " grads = tape.gradient(loss, model.variables)\n",
+ " optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())\n",
+ "\n",
+ " if batch % 100 == 0:\n",
+ " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n",
+ " batch,\n",
+ " loss))\n",
+ " \n",
+ " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
+ " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "DjGz1tDkzf-u",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Predicting using our trained model\n",
+ "\n",
+ "The below code block is used to generated the text\n",
+ "\n",
+ "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n",
+ "\n",
+ "* We get predictions using the start_string and the hidden state\n",
+ "\n",
+ "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n",
+ "\n",
+ "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n",
+ "\n",
+ "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WvuwZBX5Ogfd",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Evaluation step(generating text using the model learned)\n",
+ "\n",
+ "# number of characters to generate\n",
+ "num_generate = 1000\n",
+ "\n",
+ "# You can change the start string to experiment\n",
+ "start_string = 'Q'\n",
+ "# converting our start string to numbers(vectorizing!) \n",
+ "input_eval = [char2idx[s] for s in start_string]\n",
+ "input_eval = tf.expand_dims(input_eval, 0)\n",
+ "\n",
+ "# empty string to store our results\n",
+ "text_generated = ''\n",
+ "\n",
+ "# low temperatures results in more predictable text.\n",
+ "# higher temperatures results in more surprising text\n",
+ "# experiment to find the best setting\n",
+ "temperature = 1.0\n",
+ "\n",
+ "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n",
+ "hidden = [tf.zeros((1, units))]\n",
+ "for i in range(num_generate):\n",
+ " predictions, hidden = model(input_eval, hidden)\n",
+ "\n",
+ " # using a multinomial distribution to predict the word returned by the model\n",
+ " predictions = predictions / temperature\n",
+ " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " \n",
+ " # We pass the predicted word as the next input to the model\n",
+ " # along with the previous hidden state\n",
+ " input_eval = tf.expand_dims([predicted_id], 0)\n",
+ " \n",
+ " text_generated += idx2char[predicted_id]\n",
+ "\n",
+ "print (start_string + text_generated)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AM2Uma_-yVIq",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Next steps\n",
+ "\n",
+ "* Change the start string to a different character, or the start of a sentence.\n",
+ "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n",
+ "* Experiment with the temperature parameter.\n",
+ "* Add another RNN layer.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gtEd86sX5cB2",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
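On the "add another RNN layer" suggestion in the next-steps list: a hedged sketch, assuming the same TF 1.9-era `tf.keras` APIs the notebook uses (the class name and the list-of-states convention here are illustrative, not part of the notebook):

```python
import tensorflow as tf

class TwoLayerModel(tf.keras.Model):
  """Hypothetical two-GRU variant of the notebook's Model."""

  def __init__(self, vocab_size, embedding_dim, units):
    super(TwoLayerModel, self).__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru1 = tf.keras.layers.GRU(units, return_sequences=True,
                                    return_state=True,
                                    recurrent_activation='sigmoid',
                                    recurrent_initializer='glorot_uniform')
    self.gru2 = tf.keras.layers.GRU(units, return_sequences=True,
                                    return_state=True,
                                    recurrent_activation='sigmoid',
                                    recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

  def call(self, x, hidden):
    # hidden is a list with one state tensor per GRU layer
    x = self.embedding(x)
    x, state1 = self.gru1(x, initial_state=hidden[0])
    x, state2 = self.gru2(x, initial_state=hidden[1])
    # flatten time into the batch dimension before the dense layer,
    # as in the single-layer model
    x = tf.reshape(x, (-1, x.shape[2]))
    return self.fc(x), [state1, state2]
```

The training and generation loops would then thread a list of two states (for example `hidden = [tf.zeros((1, units)), tf.zeros((1, units))]` at generation time) instead of a single tensor.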
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 34ce5e0cc3..1f66d7e752 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -42,10 +42,10 @@
"# Neural Machine Translation with Attention\n",
"\n",
"<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a> \n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
"</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
index a18882fafa..7c0f9b5b81 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -75,10 +75,10 @@
"cell_type": "markdown",
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
"</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
index 54fbf2a7e1..a0bbbb6123 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
@@ -75,10 +75,10 @@
"cell_type": "markdown",
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
"</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 0a781d2153..591e2d0c85 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -75,10 +75,10 @@
"cell_type": "markdown",
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
"</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
index b37a18c9a6..f1e13de5de 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
@@ -75,10 +75,10 @@
"cell_type": "markdown",
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a>\n",
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
"</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
]
},
{
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index b14ef1df8f..07d8788882 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -29,6 +29,7 @@ import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.client import device_lib
+from tensorflow.python.eager import tape
def device_and_data_format():
@@ -49,13 +50,21 @@ def random_batch(batch_size, data_format):
return images, one_hot
-def compute_gradients(model, images, labels):
- with tf.GradientTape() as tape:
+def compute_gradients(model, images, labels, num_replicas=1):
+ with tf.GradientTape() as grad_tape:
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
- return tape.gradient(loss, model.variables)
+ if num_replicas != 1:
+ loss /= num_replicas
+
+ # TODO(b/110991947): We can mistakenly trace the gradient call in
+ # multi-threaded environment. Explicitly disable recording until
+ # this is fixed.
+ with tape.stop_recording():
+ grads = grad_tape.gradient(loss, model.variables)
+ return grads
def apply_gradients(model, optimizer, gradients):
@@ -188,11 +197,14 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return (32,)
return (16, 32)
- def _report(self, label, start, num_iters, device, batch_size, data_format):
+ def _report(self, label, start, num_iters, device, batch_size, data_format,
+ num_replicas=1):
avg_time = (time.time() - start) / num_iters
dev = tf.DeviceSpec.from_string(device).device_type.lower()
- name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format)
- extras = {'examples_per_sec': batch_size / avg_time}
+ replica_str = '' if num_replicas == 1 else 'replicas_%d_' % num_replicas
+ name = '%s_%s_batch_%d_%s%s' % (label, dev, batch_size,
+ replica_str, data_format)
+ extras = {'examples_per_sec': (num_replicas * batch_size) / avg_time}
self.report_benchmark(
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
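For reference, the throughput arithmetic that `_report` now performs, isolated into a sketch (`num_replicas` scales the per-step example count while wall time stays per step):

```python
def examples_per_sec(elapsed, num_iters, batch_size, num_replicas=1):
    # Mirrors the benchmark math above: average wall time of one
    # iteration, with batch_size examples processed on every replica.
    avg_time = elapsed / num_iters
    return (num_replicas * batch_size) / avg_time

# e.g. 100 iterations in 20s at batch 32 on 2 replicas ~ 320 examples/sec
print(examples_per_sec(20.0, 100, 32, num_replicas=2))
```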
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD
index 81c9facfb5..0c0e4c0eb9 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/BUILD
+++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD
@@ -78,7 +78,7 @@ cuda_py_test(
"//tensorflow:tensorflow_py",
],
tags = [
- "no_pip",
+ "no_pip", # depends on blocks_test, which is not available in pip package
"optonly",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
index 3e7abe952d..75cb3f8227 100644
--- a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
+++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
@@ -210,7 +210,7 @@
"a = tf.constant(0.0)\n",
"b = tf.constant(1.0)\n",
"epsilon = tf.constant(0.001)\n",
- "x = bisecting_line_search(test_f, a, b, epsilon)\n",
+ "x = bisecting_line_search(test_f, a, b, epsilon)\n"
],
"execution_count": 0,
"outputs": []
@@ -279,4 +279,4 @@
]
}
]
-} \ No newline at end of file
+}
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 30d297a5fb..11d40f5982 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":early_stopping",
":export",
":extenders",
":head",
@@ -590,3 +591,31 @@ py_test(
"@six_archive//:six",
],
)
+
+py_library(
+ name = "early_stopping",
+ srcs = ["python/estimator/early_stopping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator",
+ ],
+)
+
+py_test(
+ name = "early_stopping_test",
+ srcs = ["python/estimator/early_stopping_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":early_stopping",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/estimator",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 788ac5ca70..09fcfd66a1 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
+from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
@@ -63,6 +64,12 @@ _allowed_symbols = [
'RNNEstimator',
'export_saved_model_for_mode',
'export_all_saved_models',
+ 'make_early_stopping_hook',
+ 'read_eval_metrics',
+ 'stop_if_lower_hook',
+ 'stop_if_higher_hook',
+ 'stop_if_no_increase_hook',
+ 'stop_if_no_decrease_hook',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
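With these symbols exported, the new hooks become reachable from `tf.contrib.estimator`. A hedged usage sketch (the estimator, feature column, input function, and model_dir below are placeholders, not part of this change):

```python
import tensorflow as tf

# Placeholder estimator; any tf.estimator.Estimator instance works.
est = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[10],
    model_dir='/tmp/early_stopping_demo')

def input_fn():
  # Trivial placeholder input pipeline.
  return {'x': tf.constant([[1.0]])}, tf.constant([0])

# Request a stop once eval loss shows no decrease over 1000 training steps.
hook = tf.contrib.estimator.stop_if_no_decrease_hook(est, 'loss', 1000)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, hooks=[hook])
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)
tf.estimator.train_and_evaluate(est, train_spec, eval_spec)
```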
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index d0e3e670f7..505c94e971 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is bias which is [46, 58]
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
new file mode 100644
index 0000000000..af4855e91e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -0,0 +1,468 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for early stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import operator
+import os
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'
+
+
+def make_early_stopping_hook(estimator,
+ should_stop_fn,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates early-stopping hook.
+
+ Returns a `SessionRunHook` that stops training when `should_stop_fn` returns
+ `True`.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ hook = early_stopping.make_early_stopping_hook(
+ estimator, should_stop_fn=make_stop_fn(...))
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ should_stop_fn: `callable`, function that takes no arguments and returns a
+ `bool`. If the function returns `True`, stopping will be initiated by the
+ chief.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ A `SessionRunHook` that periodically executes `should_stop_fn` and initiates
+ early stopping if the function returns `True`.
+
+ Raises:
+ TypeError: If `estimator` is not of type `tf.estimator.Estimator`.
+ ValueError: If both `run_every_secs` and `run_every_steps` are set.
+ """
+ if not isinstance(estimator, estimator_lib.Estimator):
+ raise TypeError('`estimator` must have type `tf.estimator.Estimator`. '
+ 'Got: {}'.format(type(estimator)))
+
+ if run_every_secs is not None and run_every_steps is not None:
+ raise ValueError('Only one of `run_every_secs` and `run_every_steps` must '
+ 'be set.')
+
+ if estimator.config.is_chief:
+ return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps)
+ else:
+ return _CheckForStoppingHook()
+
+
+def stop_if_higher_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is higher than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy becomes higher than 0.9.
+ hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is higher than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_lower_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is lower than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss becomes lower than 100.
+ hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is lower than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_increase_hook(estimator,
+ metric_name,
+ max_steps_without_increase,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not increase within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy does not increase in over 100000 steps.
+ hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_increase: `int`, maximum number of training steps with no
+ increase in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no increase over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_increase,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_decrease_hook(estimator,
+ metric_name,
+ max_steps_without_decrease,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not decrease within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss does not decrease in over 100000 steps.
+ hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_decrease: `int`, maximum number of training steps with no
+ decrease in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no decrease over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_decrease,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def read_eval_metrics(eval_dir):
+ """Helper to read eval metrics from eval summary files.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Returns:
+ A `dict` with global steps mapping to `dict` of metric names and values.
+ """
+ eval_metrics_dict = {}
+ for event in _summaries(eval_dir):
+ if not event.HasField('summary'):
+ continue
+ metrics = {}
+ for value in event.summary.value:
+ if value.HasField('simple_value'):
+ metrics[value.tag] = value.simple_value
+ if metrics:
+ eval_metrics_dict[event.step] = metrics
+ return eval_metrics_dict
+
+
+def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
+ higher_is_better, eval_dir, min_steps,
+ run_every_secs, run_every_steps):
+ """Creates early-stopping hook to stop training if threshold is crossed."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ greater_or_lesser = 'greater than' if higher_is_better else 'less than'
+
+ def stop_if_threshold_crossed_fn():
+ """Returns `True` if the given metric crosses specified threshold."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if is_lhs_better(val, threshold):
+ tf_logging.info(
+ 'At step %s, metric "%s" has value %s which is %s the configured '
+ 'threshold (%s) for early stopping.', step, metric_name, val,
+ greater_or_lesser, threshold)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_threshold_crossed_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _stop_if_no_metric_improvement_hook(
+ estimator, metric_name, max_steps_without_improvement, higher_is_better,
+ eval_dir, min_steps, run_every_secs, run_every_steps):
+ """Returns hook to stop training if given metric shows no improvement."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ increase_or_decrease = 'increase' if higher_is_better else 'decrease'
+
+ def stop_if_no_metric_improvement_fn():
+ """Returns `True` if metric does not improve within max steps."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ best_val = None
+ best_val_step = None
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if best_val is None or is_lhs_better(val, best_val):
+ best_val = val
+ best_val_step = step
+ if step - best_val_step >= max_steps_without_improvement:
+ tf_logging.info(
+ 'No %s in metric "%s" for %s steps, which is greater than or equal '
+ 'to max steps (%s) configured for early stopping.',
+ increase_or_decrease, metric_name, step - best_val_step,
+ max_steps_without_improvement)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_no_metric_improvement_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _summaries(eval_dir):
+ """Yields `tensorflow.Event` protos from event files in the eval dir.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Yields:
+ `tensorflow.Event` object read from the event files.
+ """
+ for event_file in gfile.Glob(
+ os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):
+ for event in summary_iterator.summary_iterator(event_file):
+ yield event
+
+
+def _get_or_create_stop_var():
+ with variable_scope.variable_scope(
+ name_or_scope='signal_early_stopping',
+ values=[],
+ reuse=variable_scope.AUTO_REUSE):
+ return variable_scope.get_variable(
+ name='STOP',
+ shape=[],
+ dtype=dtypes.bool,
+ initializer=init_ops.constant_initializer(False),
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False)
+
+
+class _StopOnPredicateHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop when `should_stop_fn` returns `True`."""
+
+ def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):
+ if not callable(should_stop_fn):
+ raise TypeError('`should_stop_fn` must be callable.')
+
+ self._should_stop_fn = should_stop_fn
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=run_every_secs, every_steps=run_every_steps)
+ self._global_step_tensor = None
+ self._stop_var = None
+ self._stop_op = None
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ self._stop_var = _get_or_create_stop_var()
+ self._stop_op = state_ops.assign(self._stop_var, True)
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results
+ if self._timer.should_trigger_for_step(global_step):
+ self._timer.update_last_triggered_step(global_step)
+ if self._should_stop_fn():
+ tf_logging.info('Requesting early stopping at global step %d',
+ global_step)
+ run_context.session.run(self._stop_op)
+ run_context.request_stop()
+
+
+class _CheckForStoppingHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop if stop is requested by `_StopOnPredicateHook`."""
+
+ def __init__(self):
+ self._stop_var = None
+
+ def begin(self):
+ self._stop_var = _get_or_create_stop_var()
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._stop_var)
+
+ def after_run(self, run_context, run_values):
+ should_early_stop = run_values.results
+ if should_early_stop:
+ tf_logging.info('Early stopping requested, suspending run.')
+ run_context.request_stop()
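+
+
+# How the two hooks cooperate (a sketch; the chief/worker dispatch is an
+# assumption here, not shown in this hunk): the chief runs
+# `_StopOnPredicateHook`, which evaluates `should_stop_fn` and sets the shared
+# STOP variable, while other workers run `_CheckForStoppingHook` and merely
+# poll that variable, so all replicas stop together without re-reading the
+# eval metrics.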
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
new file mode 100644
index 0000000000..b5eee818fa
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
@@ -0,0 +1,233 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for early_stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+from tensorflow.contrib.estimator.python.estimator import early_stopping
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training_util
+
+
+class _FakeRunConfig(run_config.RunConfig):
+
+ def __init__(self, is_chief):
+ super(_FakeRunConfig, self).__init__()
+ self._is_chief = is_chief
+
+ @property
+ def is_chief(self):
+ return self._is_chief
+
+
+def _dummy_model_fn(features, labels, params):
+ _, _, _ = features, labels, params
+
+
+class _FakeEstimator(estimator.Estimator):
+ """Fake estimator for testing."""
+
+ def __init__(self, config):
+ super(_FakeEstimator, self).__init__(
+ model_fn=_dummy_model_fn, config=config)
+
+
+def _write_events(eval_dir, params):
+ """Test helper to write events to summary files."""
+ for steps, loss, accuracy in params:
+ estimator._write_dict_to_summary(eval_dir, {
+ 'loss': loss,
+ 'accuracy': accuracy,
+ }, steps)
+
+
+class ReadEvalMetricsTest(test.TestCase):
+
+ def test_read_eval_metrics(self):
+ eval_dir = tempfile.mkdtemp()
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 1, 2),
+ (2000, 3, 4),
+ (3000, 5, 6),
+ ])
+ self.assertEqual({
+ 1000: {
+ 'loss': 1,
+ 'accuracy': 2
+ },
+ 2000: {
+ 'loss': 3,
+ 'accuracy': 4
+ },
+ 3000: {
+ 'loss': 5,
+ 'accuracy': 6
+ },
+ }, early_stopping.read_eval_metrics(eval_dir))
+
+
+class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ config = _FakeRunConfig(is_chief=True)
+ self._estimator = _FakeEstimator(config=config)
+ eval_dir = self._estimator.eval_dir()
+ os.makedirs(eval_dir)
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 0.8, 0.5),
+ (2000, 0.7, 0.6),
+ (3000, 0.4, 0.7),
+ (3500, 0.41, 0.68),
+ ])
+
+ def run_session(self, hooks, should_stop):
+ hooks = hooks if isinstance(hooks, list) else [hooks]
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertEqual(mon_sess.should_stop(), should_stop)
+
+ @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True))
+ def test_stop_if_higher_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_higher_hook(
+ self._estimator,
+ metric_name='accuracy',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True))
+ def test_stop_if_lower_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_lower_hook(
+ self._estimator,
+ metric_name='loss',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_increase_hook(
+ self._estimator,
+ metric_name='accuracy',
+ max_steps_without_increase=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True),
+ (500, 0.3, True))
+ def test_multiple_hooks(self, max_steps, loss_threshold, should_stop):
+ self.run_session([
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps),
+ early_stopping.stop_if_lower_hook(
+ self._estimator, metric_name='loss', threshold=loss_threshold)
+ ], should_stop)
+
+ @parameterized.parameters(False, True)
+ def test_make_early_stopping_hook(self, should_stop):
+ self.run_session([
+ early_stopping.make_early_stopping_hook(
+ self._estimator, should_stop_fn=lambda: should_stop)
+ ], should_stop)
+
+ def test_make_early_stopping_hook_typeerror(self):
+ with self.assertRaises(TypeError):
+ early_stopping.make_early_stopping_hook(
+ estimator=object(), should_stop_fn=lambda: True)
+
+ def test_make_early_stopping_hook_valueerror(self):
+ with self.assertRaises(ValueError):
+ early_stopping.make_early_stopping_hook(
+ self._estimator,
+ should_stop_fn=lambda: True,
+ run_every_secs=60,
+ run_every_steps=100)
+
+
+class StopOnPredicateHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: False, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ self.assertFalse(mon_sess.raw_session().run(hook._stop_var))
+
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: True, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertTrue(mon_sess.should_stop())
+ self.assertTrue(mon_sess.raw_session().run(hook._stop_var))
+
+
+class CheckForStoppingHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._CheckForStoppingHook()
+ with ops.Graph().as_default():
+ no_op = control_flow_ops.no_op()
+ assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(),
+ True)
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ mon_sess.run(assign_op)
+ self.assertTrue(mon_sess.should_stop())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index b305f37791..10a8796bcb 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -45,6 +45,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
@@ -59,6 +60,7 @@ py_test(
deps = [
":features",
":namedtuples",
+ ":random_tensor_pool",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim:learning",
@@ -70,6 +72,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
@@ -188,6 +191,7 @@ py_test(
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":namedtuples",
":tuple_losses",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -344,9 +348,11 @@ py_library(
"//tensorflow/python:image_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "@six_archive//:six",
],
)
@@ -470,12 +476,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":head",
":namedtuples",
":summaries",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator",
@@ -498,16 +504,19 @@ py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
- "//tensorflow/python/estimator:head",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 4092b32004..8e4affb9b4 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -24,11 +24,11 @@ import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
-from tensorflow.contrib.gan.python.estimator.python import head as head_lib
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_inspect as inspect
@@ -154,94 +154,93 @@ class GANEstimator(estimator.Estimator):
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If loss functions aren't callable.
+ ValueError: If `use_loss_summaries` isn't boolean or `None`.
+ TypeError: If `get_hooks_fn` isn't callable or `None`.
"""
- # TODO(joelshor): Explicitly validate inputs.
+ if not callable(generator_loss_fn):
+ raise ValueError('generator_loss_fn must be callable.')
+ if not callable(discriminator_loss_fn):
+ raise ValueError('discriminator_loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
def _model_fn(features, labels, mode):
- gopt = (generator_optimizer() if callable(generator_optimizer) else
- generator_optimizer)
- dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
- else discriminator_optimizer)
- gan_head = head_lib.gan_head(
- generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn,
- get_eval_metric_ops_fn=get_eval_metric_ops_fn)
- return _gan_model_fn(
- features, labels, mode, generator_fn, discriminator_fn, gan_head,
+ """GANEstimator model function."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT]:
+ raise ValueError('Mode not recognized: %s' % mode)
+ real_data = labels # rename inputs for clarity
+ generator_inputs = features # rename inputs for clarity
+
+ # Make GANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
add_summaries)
+ # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
+ # metrics, and optimizers (if required).
+ return _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn)
+
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
-def _gan_model_fn(
- features,
- labels,
- mode,
- generator_fn,
- discriminator_fn,
- head,
- add_summaries=None,
- generator_scope_name='Generator'):
- """The `model_fn` for the GAN estimator.
-
- We make the following convention:
- features -> TFGAN's `generator_inputs`
- labels -> TFGAN's `real_data`
-
- Args:
- features: A dictionary to feed to generator. In the unconditional case,
- this might be just `noise`. In the conditional GAN case, this
- might be the generator's conditioning. The `generator_fn` determines
- what the required keys are.
- labels: Real data. Can be any structure, as long as `discriminator_fn`
- can accept it for the first argument.
- mode: Defines whether this is training, evaluation or prediction.
- See `ModeKeys`.
- generator_fn: A python lambda that takes `generator_inputs` as inputs and
- returns the outputs of the GAN generator.
- discriminator_fn: A python lambda that takes `real_data`/`generated data`
- and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
- head: A `Head` instance suitable for GANs.
- add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
- generator_scope_name: The name of the generator scope. We need this to be
- the same for GANModels produced by TFGAN's `train.gan_model` and the
- manually constructed ones for predictions.
-
- Returns:
- `ModelFnOps`
-
- Raises:
- ValueError: If `labels` isn't `None` during prediction.
- """
- real_data = labels
- generator_inputs = features
-
- if mode == model_fn_lib.ModeKeys.TRAIN:
- gan_model = _make_train_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- gan_model = _make_eval_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- else:
+def _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries, generator_scope='Generator'):
+ """Makes the GANModel tuple, which encapsulates the GAN model architecture."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
if real_data is not None:
raise ValueError('`labels` must be `None` when mode is `predict`. '
'Instead, found %s' % real_data)
gan_model = _make_prediction_gan_model(
- generator_inputs, generator_fn, generator_scope_name)
+ generator_inputs, generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(
+ generator_fn, discriminator_fn, real_data, generator_inputs,
+ generator_scope, add_summaries, mode)
- return head.create_estimator_spec(
- features=None,
- mode=mode,
- logits=gan_model,
- labels=None)
+ return gan_model
+
+
+def _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = tfgan_tuples.GANLoss(
+ generator_loss=generator_loss_fn(gan_model),
+ discriminator_loss=discriminator_loss_fn(gan_model))
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(
+ gan_model, gan_loss, get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (generator_optimizer() if callable(generator_optimizer) else
+ generator_optimizer)
+ dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
+ else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(
+ gan_model, gan_loss, gopt, dopt, get_hooks_fn)
+
+ return estimator_spec
def _make_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries, mode):
- """Make a `GANModel`, and optionally pass in `mode`."""
+ """Construct a `GANModel`, and optionally pass in `mode`."""
# If network functions have an argument `mode`, pass mode to it.
if 'mode' in inspect.getargspec(generator_fn).args:
generator_fn = functools.partial(generator_fn, mode=mode)
@@ -264,22 +263,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data,
return gan_model
-def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for training."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.TRAIN)
-
-
-def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for evaluation."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.EVAL)
-
-
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
"""Make a `GANModel` from just the generator."""
# If `generator_fn` has an argument `mode`, pass mode to it.
@@ -303,3 +286,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
discriminator_variables=None,
discriminator_scope=None,
discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss,
+ gan_loss.discriminator_loss]):
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(
+ gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
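+
+
+# A minimal construction sketch (`my_generator_fn` and `my_discriminator_fn`
+# are placeholders; the losses and optimizers are standard TF-GAN and
+# tf.train symbols, assumed rather than defined in this file):
+#
+#   gan_estimator = GANEstimator(
+#       generator_fn=my_generator_fn,
+#       discriminator_fn=my_discriminator_fn,
+#       generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
+#       discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
+#       generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
+#       discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))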
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 955482599b..9ac9c6ca9c 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -21,30 +21,30 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
from tensorflow.contrib import layers
-from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator
from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
-from tensorflow.python.training import monitored_session
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode):
return layers.fully_connected(data, 1)
-def mock_head(testcase, expected_generator_inputs, expected_real_data,
- generator_scope_name):
- """Returns a mock head that validates logits values and variable names."""
- discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults
- generator_var_names = set([
- '%s/fully_connected/weights:0' % generator_scope_name,
- '%s/fully_connected/biases:0' % generator_scope_name])
- discriminator_var_names = set([
- '%s/fully_connected/weights:0' % discriminator_scope_name,
- '%s/fully_connected/biases:0' % discriminator_scope_name])
-
- def _create_estimator_spec(features, mode, logits, labels):
- gan_model = logits # renaming for clarity
- is_predict = mode == model_fn_lib.ModeKeys.PREDICT
- testcase.assertIsNone(features)
- testcase.assertIsNone(labels)
- testcase.assertIsInstance(gan_model, namedtuples.GANModel)
-
- trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- expected_var_names = (generator_var_names if is_predict else
- generator_var_names | discriminator_var_names)
- testcase.assertItemsEqual(expected_var_names,
- [var.name for var in trainable_vars])
-
- assertions = []
- def _or_none(x):
- return None if is_predict else x
- testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs)
- # TODO(joelshor): Add check on `generated_data`.
- testcase.assertItemsEqual(
- generator_var_names,
- set([x.name for x in gan_model.generator_variables]))
- testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
- testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
- # TODO(joelshor): Add check on `discriminator_real_outputs`.
- # TODO(joelshor): Add check on `discriminator_gen_outputs`.
- if is_predict:
- testcase.assertIsNone(gan_model.discriminator_scope)
- else:
- testcase.assertEqual(discriminator_scope_name,
- gan_model.discriminator_scope.name)
-
- with ops.control_dependencies(assertions):
- if mode == model_fn_lib.ModeKeys.TRAIN:
- return model_fn_lib.EstimatorSpec(
- mode=mode, loss=array_ops.zeros([]),
- train_op=control_flow_ops.no_op(), training_hooks=[])
- elif mode == model_fn_lib.ModeKeys.EVAL:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data,
- loss=array_ops.zeros([]))
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data)
- else:
- testcase.fail('Invalid mode: {}'.format(mode))
-
- head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
- head.create_estimator_spec = test.mock.MagicMock(
- wraps=_create_estimator_spec)
-
- return head
-
-
-class GANModelFnTest(test.TestCase):
- """Tests that _gan_model_fn passes expected logits to mock head."""
-
- def setUp(self):
- self._model_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- if self._model_dir:
- writer_cache.FileWriterCache.clear()
- shutil.rmtree(self._model_dir)
+class GetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `GetGANModel` produces the correct model."""
- def _test_logits_helper(self, mode):
- """Tests that the expected logits are passed to mock head."""
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
with ops.Graph().as_default():
- training_util.get_or_create_global_step()
- generator_inputs = {'x': array_ops.zeros([5, 4])}
- real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
- array_ops.zeros([5, 4]))
- generator_scope_name = 'generator'
- head = mock_head(self,
- expected_generator_inputs=generator_inputs,
- expected_real_data=real_data,
- generator_scope_name=generator_scope_name)
- estimator_spec = estimator._gan_model_fn(
- features=generator_inputs,
- labels=real_data,
- mode=mode,
- generator_fn=generator_fn,
- discriminator_fn=discriminator_fn,
- generator_scope_name=generator_scope_name,
- head=head)
- with monitored_session.MonitoredTrainingSession(
- checkpoint_dir=self._model_dir) as sess:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- sess.run(estimator_spec.train_op)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- sess.run(estimator_spec.loss)
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- sess.run(estimator_spec.predictions)
- else:
- self.fail('Invalid mode: {}'.format(mode))
-
- def test_logits_predict(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT)
-
- def test_logits_eval(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.EVAL)
-
- def test_logits_train(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN)
+ generator_inputs = {'x': array_ops.ones([3, 4])}
+ real_data = (array_ops.zeros([3, 4]) if
+ mode != model_fn_lib.ModeKeys.PREDICT else None)
+ gan_model = estimator._get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries=False)
+
+ self.assertEqual(generator_inputs, gan_model.generator_inputs)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.real_data)
+ self.assertIsNone(gan_model.discriminator_real_outputs)
+ self.assertIsNone(gan_model.discriminator_gen_outputs)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertIsNotNone(gan_model.real_data)
+ self.assertIsNotNone(gan_model.discriminator_real_outputs)
+ self.assertIsNotNone(gan_model.discriminator_gen_outputs)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.GANModel(
+ generator_inputs=None,
+ generated_data=array_ops.ones([3, 4]),
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ real_data=array_ops.zeros([3, 4]),
+ discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
+ discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
+ gan_model.discriminator_gen_outputs)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ generator_loss_fn=dummy_loss_fn,
+ discriminator_loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
# TODO(joelshor): Add pandas test.
@@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase):
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
- def get_metrics(gan_model):
- return {
- 'mse_custom_metric': metrics_lib.mean_squared_error(
- gan_model.real_data, gan_model.generated_data)
- }
-
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index d1441e1eb2..1a0ee6dfc4 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -27,16 +27,21 @@ from tensorflow.python.estimator.canned import head
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.util import deprecation
__all__ = [
'GANHead',
'gan_head',
]
+
def _summary_key(head_name, val):
return '%s/%s' % (val, head_name) if head_name else val
+@deprecation.deprecated(
+ None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a '
+ 'GANHead.')
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
@@ -77,6 +82,9 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
class GANHead(head._Head): # pylint: disable=protected-access
"""`Head` for a GAN."""
+ @deprecation.deprecated(
+ None, 'Please use tf.contrib.gan.GANEstimator without explicitly making '
+ 'a GANHead.')
def __init__(self, generator_loss_fn, discriminator_loss_fn,
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
@@ -108,7 +116,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
raise TypeError('generator_loss_fn must be callable.')
if not callable(discriminator_loss_fn):
raise TypeError('discriminator_loss_fn must be callable.')
- if not use_loss_summaries in [True, False, None]:
+ if use_loss_summaries not in [True, False, None]:
raise ValueError('use_loss_summaries must be True, False or None.')
if get_hooks_fn is not None and not callable(get_hooks_fn):
raise TypeError('get_hooks_fn must be callable.')
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 5309d87765..8205bc889d 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -67,7 +67,7 @@ class GANHeadTest(test.TestCase):
generator_optimizer=training.GradientDescentOptimizer(1.0),
discriminator_optimizer=training.GradientDescentOptimizer(1.0),
get_eval_metric_ops_fn=self.get_metrics)
- self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ self.assertIsInstance(self.gan_head, head.GANHead)
def get_metrics(self, gan_model):
self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc
new file mode 100644
index 0000000000..8cdf16103b
--- /dev/null
+++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KafkaDataset")
+ .Input("topics: string")
+ .Input("servers: string")
+ .Input("group: string")
+ .Input("eof: bool")
+ .Input("timeout: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the messages of one or more Kafka topics.
+
+topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length];
+ by default, length is -1, meaning unlimited.
+servers: A list of bootstrap servers.
+group: The consumer group id.
+eof: If True, the Kafka reader will stop on EOF.
+timeout: The timeout value for the Kafka consumer to wait
+ (in milliseconds).
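+
+Example `topics` values (a sketch; topic names are placeholders):
+ "news:0:0:-1" reads partition 0 of topic "news" from offset 0 with no
+ length limit, and "news:0:100:50" reads 50 messages starting at offset 100.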
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0e35b1aa8b..dad3da3748 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
original_vars = set(tape.watched_variables())
# Backward pass
- def grad_fn(*output_grads, **kwargs):
+ def _grad_fn(output_grads, variables=None):
"""Recompute outputs for gradient computation."""
- variables = []
+ variables = variables or []
if original_vars:
- variables = kwargs["variables"]
- if set(variables) != original_vars:
- raise ValueError(_WRONG_VARS_ERR)
- del kwargs
- inputs = list(args)
+ assert variables, ("Fn created variables but the variables were not "
+ "passed to the gradient fn.")
+ if set(variables) != original_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+ inputs = [array_ops.identity(x) for x in list(args)]
# Recompute outputs
with framework_ops.control_dependencies(output_grads):
if use_data_dep_:
@@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if original_vars != recompute_vars:
raise ValueError(_WRONG_VARS_ERR)
- if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
+ if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
outputs = list(outputs)
grads = gradients_impl.gradients(outputs, inputs + variables,
@@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
grad_vars = grads[len(inputs):]
return grad_inputs, grad_vars
+ # custom_gradient inspects the signature of the function to determine
+ # whether the user expects variables passed in the grad_fn. If the function
+ # created variables, the grad_fn should accept the "variables" kwarg.
+ if original_vars:
+ def grad_fn(*output_grads, **kwargs):
+ return _grad_fn(output_grads, kwargs["variables"])
+ else:
+ def grad_fn(*output_grads):
+ return _grad_fn(output_grads)
+
return outputs, grad_fn
return fn_with_recompute(*args)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index bc09ba8d43..d5971fb9d8 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase):
self.assertEqual(2, len(update_ops))
self.assertEqual([False, True], kwarg_values)
+ def testWithoutVariables(self):
+
+ def concat_n(layer_list, num_inputs):
+ return math_ops.reduce_sum(
+ array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1),
+ axis=1, keepdims=True)
+
+ @rev_block_lib.recompute_grad
+ def concat_n_wrap(*args):
+ return concat_n(args, 3)
+
+ # DenseNet-style layers
+ layer_list = [random_ops.random_uniform((4, 8))]
+ for _ in range(5):
+ layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
+
+ grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
+ with self.test_session() as sess:
+ sess.run(grads)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index 5e7b422e3c..e742447208 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -625,11 +625,13 @@ def attention_decoder(decoder_inputs,
v = []
attention_vec_size = attn_size # Size of query vectors for attention.
for a in xrange(num_heads):
- k = variable_scope.get_variable("AttnW_%d" % a,
- [1, 1, attn_size, attention_vec_size])
+ k = variable_scope.get_variable(
+ "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size],
+ dtype=dtype)
hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
v.append(
- variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
+ variable_scope.get_variable(
+ "AttnV_%d" % a, [attention_vec_size], dtype=dtype))
state = initial_state
@@ -647,11 +649,13 @@ def attention_decoder(decoder_inputs,
with variable_scope.variable_scope("Attention_%d" % a):
y = Linear(query, attention_vec_size, True)(query)
y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
+ y = math_ops.cast(y, dtype)
# Attention mask is a softmax of v^T * tanh(...).
s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
[2, 3])
- a = nn_ops.softmax(s)
+ a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32))
# Now calculate the attention-weighted vector d.
+ a = math_ops.cast(a, dtype)
d = math_ops.reduce_sum(
array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
ds.append(array_ops.reshape(d, [-1, attn_size]))
@@ -681,6 +685,7 @@ def attention_decoder(decoder_inputs,
raise ValueError("Could not infer input size from input: %s" % inp.name)
inputs = [inp] + attns
+ inputs = [math_ops.cast(e, dtype) for e in inputs]
x = Linear(inputs, input_size, True)(inputs)
# Run the RNN.
cell_output, state = cell(x, state)
@@ -693,6 +698,7 @@ def attention_decoder(decoder_inputs,
attns = attention(state)
with variable_scope.variable_scope("AttnOutputProjection"):
+ cell_output = math_ops.cast(cell_output, dtype)
inputs = [cell_output] + attns
output = Linear(inputs, output_size, True)(inputs)
if loop_function is not None:
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index 5b89c6cef9..fe0ba19fcb 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -41,6 +41,7 @@ py_test(
size = "medium",
srcs = ["python/kernel_tests/sdca_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows_gpu"],
deps = [
":sdca_ops_py",
":sparse_feature_column_py",
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 73f5c1448d..b95d4d0fce 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -146,6 +146,7 @@ cc_library(
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
+ ":string",
":util",
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 68aee2e644..827ea86503 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 4257e754ad..16a0e71624 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -36,12 +36,13 @@ struct AllocationInfo {
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
std::unique_ptr<GraphInfo> graph_info,
- bool preserve_inputs)
+ bool preserve_inputs, bool preserve_intermediates)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
persistent_arena_(kDefaultArenaAlignment),
- preserve_inputs_(preserve_inputs) {}
+ preserve_inputs_(preserve_inputs),
+ preserve_intermediates_(preserve_intermediates) {}
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -164,13 +165,15 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Then update the ref-counts of the node's inputs, and if necessary queue
// them for deallocation.
- TfLiteIntArray* node_inputs = node.inputs;
- for (int j = 0; j < node_inputs->size; ++j) {
- int tensor_index = node_inputs->data[j];
- if (tensor_index != kOptionalTensor) {
- refcounts[tensor_index]--;
- if (refcounts[tensor_index] == 0) {
- TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
+ if (!preserve_intermediates_) {
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int j = 0; j < node_inputs->size; ++j) {
+ int tensor_index = node_inputs->data[j];
+ if (tensor_index != kOptionalTensor) {
+ refcounts[tensor_index]--;
+ if (refcounts[tensor_index] == 0) {
+ TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
+ }
}
}
}
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 1d84950e91..82c866734f 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -47,7 +47,7 @@ class ArenaPlanner : public MemoryPlanner {
// graph will not share memory with any other tensor, effectively preserving
// them until the end of inference.
ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info,
- bool preserve_inputs);
+ bool preserve_inputs, bool preserve_intermediates);
~ArenaPlanner() override;
ArenaPlanner(const ArenaPlanner&) = delete;
ArenaPlanner& operator=(const ArenaPlanner&) = delete;
@@ -104,7 +104,14 @@ class ArenaPlanner : public MemoryPlanner {
// declared as kTfLiteArenaRwPersistent.
SimpleMemoryArena persistent_arena_;
+ // Ensure that the memory allocated for the graph's inputs is never reused
+ // by the allocator. This allows, for example, running inference multiple
+ // times without getting unpredictable results.
bool preserve_inputs_;
+
+ // If true, memory areas are never overlapped, so intermediate results can
+ // still be queried after running (modulo running delegates).
+ bool preserve_intermediates_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index f5bd1932f9..1adb426d58 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -156,7 +156,7 @@ class ArenaPlannerTest : public ::testing::Test {
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
&context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
- preserve_inputs));
+ preserve_inputs, /*preserve_intermediates=*/false));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 5543acc1f5..b735d08b4b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -195,7 +195,7 @@ def json_to_tflite(name, src, out):
def generated_test_models():
return [
"add",
- "arg_max",
+ "arg_min_max",
"avg_pool",
"batch_to_space_nd",
"concat",
@@ -232,7 +232,7 @@ def generated_test_models():
"not_equal",
"pad",
"padv2",
- # "prelu",
+ "prelu",
"pow",
"relu",
"relu1",
@@ -257,7 +257,7 @@ def generated_test_models():
"tile",
"topk",
"transpose",
- "transpose_conv",
+ #"transpose_conv", # disabled due to b/111213074
"where",
]
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index cda889bf50..a24aaad7dd 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -250,6 +250,10 @@ typedef struct {
} TfLiteArgMaxParams;
typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
@@ -263,6 +267,16 @@ typedef struct {
TfLiteType out_type;
} TfLiteShapeParams;
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index a44e918230..6bde5d2e6d 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -104,6 +104,8 @@ typedef enum {
kTfLiteBuiltinRsqrt = 76,
kTfLiteBuiltinShape = 77,
kTfLiteBuiltinPow = 78,
+ kTfLiteBuiltinArgMin = 79,
+ kTfLiteBuiltinFakeQuant = 80,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
new file mode 100644
index 0000000000..066b106215
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -0,0 +1,35 @@
+#
+# This is a TF Lite delegate that is powered by TensorFlow's Eager execution.
+#
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "util",
+ srcs = ["util.cc"],
+ hdrs = ["util.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "util_test",
+ size = "small",
+ srcs = ["util_test.cc"],
+ tags = [
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":util",
+ "//tensorflow/contrib/lite/testing:util",
+ "//tensorflow/core:lib",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
new file mode 100644
index 0000000000..04a852e515
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
+
+namespace tflite {
+
+TfLiteStatus ConvertStatus(TfLiteContext* context,
+ const tensorflow::Status& status) {
+ if (!status.ok()) {
+ context->ReportError(context, "%s", status.error_message().c_str());
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
+ TfLiteTensor* tensor) {
+ int num_dims = src.dims();
+ TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
+ for (int j = 0; j < num_dims; ++j) {
+ // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's
+ // make sure there's no overflow.
+ if (src.dim_size(j) >= std::numeric_limits<int>::max()) {
+ context->ReportError(context,
+ "Dimension value in TensorFlow shape is larger than "
+ "supported by TF Lite");
+ TfLiteIntArrayFree(shape);
+ return kTfLiteError;
+ }
+ shape->data[j] = static_cast<int>(src.dim_size(j));
+ }
+ return context->ResizeTensor(context, tensor, shape);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
new file mode 100644
index 0000000000..2696ca8d0d
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tflite {
+
+// Converts a tensorflow::Status into a TfLiteStatus. If the original status
+// represented an error, reports it using the given 'context'.
+TfLiteStatus ConvertStatus(TfLiteContext* context,
+ const tensorflow::Status& status);
+
+// Copies the shape of the given 'src' tensor into a TF Lite 'tensor'. Logs an
+// error and returns kTfLiteError if the shape can't be converted.
+TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
+ TfLiteTensor* tensor);
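+
+// A typical call pattern inside a delegate kernel (a sketch; `context`,
+// `status`, `src` and `tensor` come from the surrounding node evaluation):
+//
+// TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));
+// TF_LITE_ENSURE_STATUS(CopyShape(context, src, tensor));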
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
new file mode 100644
index 0000000000..563f82dec3
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
+
+#include <cstdarg>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+struct TestContext : public TfLiteContext {
+ string error;
+ std::vector<int> new_size;
+};
+
+void ReportError(TfLiteContext* context, const char* format, ...) {
+ TestContext* c = static_cast<TestContext*>(context);
+ const size_t kBufferSize = 1024;
+ char temp_buffer[kBufferSize];
+
+ va_list args;
+ va_start(args, format);
+ vsnprintf(temp_buffer, kBufferSize, format, args);
+ va_end(args);
+
+ c->error = temp_buffer;
+}
+
+TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size) {
+ TestContext* c = static_cast<TestContext*>(context);
+ c->new_size.clear();
+ for (int i = 0; i < new_size->size; ++i) {
+ c->new_size.push_back(new_size->data[i]);
+ }
+ TfLiteIntArrayFree(new_size);
+ return kTfLiteOk;
+}
+
+TEST(UtilTest, ConvertStatus) {
+ TestContext context;
+ context.ReportError = ReportError;
+
+ EXPECT_EQ(ConvertStatus(&context, tensorflow::errors::Internal("Some Error")),
+ kTfLiteError);
+ EXPECT_EQ(context.error, "Some Error");
+
+ context.error.clear();
+ EXPECT_EQ(ConvertStatus(&context, tensorflow::Status()), kTfLiteOk);
+ EXPECT_TRUE(context.error.empty());
+}
+
+TEST(UtilTest, CopyShape) {
+ TestContext context;
+ context.ReportError = ReportError;
+ context.ResizeTensor = ResizeTensor;
+
+ using tensorflow::DT_FLOAT;
+ using tensorflow::Tensor;
+
+ TfLiteTensor dst;
+
+ EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
+ EXPECT_THAT(context.new_size, ElementsAre(0));
+
+ EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk);
+ EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+
+ EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
+ kTfLiteError);
+ EXPECT_EQ(context.error,
+ "Dimension value in TensorFlow shape is larger than supported by "
+ "TF Lite");
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index fd798c209e..f0d16575ec 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -452,6 +452,22 @@ class NNAPIDelegateKernel {
} else {
return nullptr;
}
+ case kTfLiteBuiltinTranspose:
+ // Transpose requires NNAPI1.1. Also note that the permutation input
+ // tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((version == 1) &&
+ (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) &&
+ (node->inputs->size > 1) &&
+ (context->tensors[node->inputs->data[1]].allocation_type ==
+ kTfLiteMmapRo)) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_TRANSPOSE;
+ };
+ } else {
+ return nullptr;
+ }
break;
default:
return nullptr;
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index aad10c9ce7..ab2181e8ff 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -27,14 +27,20 @@ using ::testing::ElementsAreArray;
// TODO(b/110368244): figure out how to share the existing tests in kernels/ but
// with the delegation on. Also, add more unit tests to improve code coverage.
-class FloatAddOpModel : public SingleOpModel {
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+ SingleOpModelWithNNAPI() {
+ this->SetApplyDelegate([](Interpreter* interpreter) {
+ interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false);
+ });
+ }
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
public:
FloatAddOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -81,9 +87,6 @@ class FloatMulOpModel : public SingleOpModel {
FloatMulOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -114,15 +117,11 @@ TEST(NNAPIDelegate, MulWithNoActivation) {
ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
}
-class FloatPoolingOpModel : public SingleOpModel {
+class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
public:
FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
int filter_width, int filter_height,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
@@ -193,10 +192,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
int dilation_width_factor = 1, int dilation_height_factor = 1) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -344,14 +339,10 @@ TEST(NNAPIDelegate, Conv2DWithNoActivation) {
}));
}
-class DepthwiseConvolutionOpModel : public SingleOpModel {
+class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
public:
DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -426,15 +417,11 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
}));
}
-class FloatFullyConnectedOpModel : public SingleOpModel {
+class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI {
public:
FloatFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& output = {TensorType_FLOAT32})
: batches_(batches), units_(units) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
int total_input_size = 1;
for (int i = 0; i < input.shape.size(); ++i) {
total_input_size *= input.shape[i];
@@ -515,14 +502,10 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-class SoftmaxOpModel : public SingleOpModel {
+class SoftmaxOpModel : public SingleOpModelWithNNAPI {
public:
SoftmaxOpModel(int batches, int size, float beta)
: batches_(batches), input_size_(size), beta_(beta) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
@@ -566,14 +549,10 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) {
1e-6)));
}
-class ReshapeOpModel : public SingleOpModel {
+class ReshapeOpModel : public SingleOpModelWithNNAPI {
public:
ReshapeOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> new_shape) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
new_shape_ = AddInput(TensorType_INT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -605,14 +584,10 @@ TEST(NNAPIDelegate, ReshapeSimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
-class SqueezeOpModel : public SingleOpModel {
+class SqueezeOpModel : public SingleOpModelWithNNAPI {
public:
SqueezeOpModel(const TensorData& input, const TensorData& output,
std::initializer_list<int> axis) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
SetBuiltinOp(
@@ -666,6 +641,43 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+class TransposeSimpleModel : public SingleOpModelWithNNAPI {
+ public:
+ TransposeSimpleModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> perm_shape,
+ std::initializer_list<int> perm) {
+ input_ = AddInput(TensorType_FLOAT32);
+ perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+ CreateTransposeOptions(builder_).Union());
+ BuildInterpreter({input_shape, perm_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int perm_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, TransposeSimpleTest) {
+ TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+ 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
new file mode 100644
index 0000000000..8e12bd04dd
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -0,0 +1,19 @@
+# TF Lite Android App Example
+
+## Building from Source with Bazel
+
+1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel).
+
+2. Build the app with Bazel. The demo needs C++11. The `--fat_apk_cpu` flag packages support for four hardware variants. You may replace it with `--config=android_arm64` on a 64-bit device or `--config=android_arm` for a 32-bit device:
+
+ ```shell
+ bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+ //tensorflow/contrib/lite/examples/android:tflite_demo
+ ```
+
+3. Install the demo on a
+ [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install):
+
+ ```shell
+ adb install bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk
+ ```
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt
new file mode 100644
index 0000000000..d581f733e4
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt
@@ -0,0 +1,38 @@
+???
+Abyssinian
+american_bulldog
+american_pit_bull_terrier
+basset_hound
+beagle
+Bengal
+Birman
+Bombay
+boxer
+British_Shorthair
+chihuahua
+Egyptian_Mau
+english_cocker_spaniel
+english_setter
+german_shorthaired
+great_pyrenees
+havanese
+japanese_chin
+keeshond
+leonberger
+Maine_Coon
+miniature_pinscher
+newfoundland
+Persian
+pomeranian
+pug
+Ragdoll
+Russian_Blue
+saint_bernard
+samoyed
+scottish_terrier
+shiba_inu
+Siamese
+Sphynx
+staffordshire_bull_terrier
+wheaten_terrier
+yorkshire_terrier
diff --git a/tensorflow/contrib/lite/g3doc/benchmarks.md b/tensorflow/contrib/lite/g3doc/benchmarks.md
new file mode 100644
index 0000000000..96536cba27
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/benchmarks.md
@@ -0,0 +1,178 @@
+# Performance benchmark numbers
+
+This document contains performance benchmark numbers for running a few
+well-known models on some Android and iOS devices.
+
+The benchmark numbers were generated by running the [TFLite benchmark
+binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
+on Android and running the [iOS benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+on iOS.
+
+## Android benchmarks
+
+When running Android benchmarks, the CPU affinity is set to use big cores on the
+device to reduce variance (see
+[details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+Models are assumed to have been downloaded from the links in the table below,
+unzipped, and pushed to the `/data/local/tmp/tflite_models` folder. The
+benchmark binary is built according to the instructions listed
+[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
+and is assumed to have been pushed to `/data/local/tmp`.
+
+The following command was used to run the benchmark:
+
+```
+adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
+ --num_threads=1 \
+ --graph=/data/local/tmp/tflite_models/${GRAPH} \
+ --warmup_runs=1 \
+ --num_runs=50 \
+ --use_nnapi=false
+```
+
+where `${GRAPH}` is the name of the model and `${CPU_MASK}` is the CPU affinity
+mask chosen according to the following table:
+
+Device   | CPU_MASK
+-------- | --------
+Pixel 2  | f0
+Pixel XL | 0c
+
+(The hexadecimal `taskset` masks select the big cores on these devices: `f0`
+pins the benchmark to CPUs 4-7, and `0c` to CPUs 2-3.)
+
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>166.5 ms (2.6 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>122.9 ms (1.8 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>69.5 ms (0.9 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>78.9 ms (2.2 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>273.8 ms (3.5 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>210.8 ms (4.2 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>234.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>158.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>2846.0 ms (15.0 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>1973.0 ms (15.0 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>3180.0 ms (11.7 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel XL</td>
+ <td>2262.0 ms (21.0 ms) </td>
+ </tr>
+
+ </table>
+
+## iOS benchmarks
+
+For the iOS benchmarks, the [benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+was modified to include the appropriate model, and `benchmark_params.json` was
+updated to set `num_threads` to 1.
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>32.2 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>24.4 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>60.3 ms (0.6 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>44.3 ms (0.7 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>iPhone 8</td>
+ <td>562.4 ms (18.2 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>661.0 ms (29.2 ms)</td>
+ </tr>
+ </table>
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index c1c8ef049f..4e7d33a1b6 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -39,22 +39,22 @@ single thread large core.
Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms
-Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms
-Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms
-Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms
-Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms
-Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms
-Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms
-Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms
-Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms
-Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms
-Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms
-Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms
-Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms
-Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms
-Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms
-Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms
+Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.7% | 65.8% | 3.7 ms
+Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 41.9% | 69.1% | 5.5 ms
+Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.3% | 71.9% | 7.9 ms
+Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 46.4% | 73.8% | 10.4 ms
+Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.1% | 78.9% | 8.8 ms
+Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.6% | 81.3% | 13.0 ms
+Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.1% | 83.2% | 18.3 ms
+Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.0% | 84.5% | 24.7 ms
+Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 52.5% | 82.8% | 16.2 ms
+Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.6% | 85.5% | 24.3 ms
+Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 61.1% | 87.1% | 33.8 ms
+Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.7% | 88.1% | 45.4 ms
+Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 62.7% | 85.5% | 24.9 ms
+Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.6% | 87.7% | 37.4 ms
+Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.9% | 51.9 ms
+Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_07_12/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.3% | 89.5% | 70.2 ms
## Other models
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index dcd17bbeab..49d00a66ba 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -42,6 +42,7 @@ counterparts:
*as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the
crops attribute is not used*
* [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp)
+* [tf.fake_quant*](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long
as the second argument is constant and transposition is not used*
* [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool)
@@ -790,6 +791,30 @@ Outputs {
}
```
+**ARG_MAX**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor specifying the axis along which to compute the arg max
+}
+Outputs {
+ 0: A tensor of indices of maximum values.
+}
+```
+
+**ARG_MIN**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor specifying the axis along which to compute the arg min
+}
+Outputs {
+ 0: A tensor of indices of minimum values.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 521216a4f1..d103786694 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -441,6 +441,13 @@ TfLiteStatus Interpreter::AllocateTensors() {
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
state_ = kStateInvokable;
+
+ // Reset the variable tensors to zero after (re)allocating the tensors.
+ // Developers shouldn't rely on the side effect of this function to reset
+ // variable tensors. They should call `ResetVariableTensorsToZero` directly
+ // instead.
+ ResetVariableTensorsToZero();
+
return kTfLiteOk;
}
@@ -565,6 +572,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
nodes_and_registration_[node_index].second;
EnsureTensorsVectorCapacity();
if (OpPrepare(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to prepare.\n",
+ node_index);
return kTfLiteError;
}
@@ -584,7 +593,7 @@ TfLiteStatus Interpreter::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
&context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
- /*preserve_inputs=*/true));
+ /*preserve_inputs=*/true, /*preserve_intermediates=*/false));
memory_planner_->PlanAllocations();
}
@@ -665,6 +674,8 @@ TfLiteStatus Interpreter::Invoke() {
EnsureTensorsVectorCapacity();
tensor_resized_since_op_invoke_ = false;
if (OpInvoke(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to invoke.\n",
+ node_index);
status = kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index b69c50fbfc..1a1c3e272b 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -63,6 +63,10 @@ template <>
constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
return kTfLiteComplex64;
}
+template <>
+constexpr TfLiteType typeToTfLiteType<string>() {
+ return kTfLiteString;
+}
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 4fa97512fc..10119903fe 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -57,6 +57,22 @@ TEST(BasicInterpreter, InvokeInvalidModel) {
ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
}
+TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) {
+ Interpreter interpreter;
+ int tensor_index;
+ ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
+ constexpr int kTensorSize = 16;
+ interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "",
+ {kTensorSize}, {}, true);
+ interpreter.SetVariables({tensor_index});
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ TfLiteTensor* tensor = interpreter.tensor(tensor_index);
+ // Ensure that variable tensors are reset to zero.
+ for (int i = 0; i < kTensorSize; ++i) {
+ ASSERT_EQ(tensor->data.f[i], 0.0f);
+ }
+}
+
// Test size accessor functions.
TEST(BasicInterpreter, TestSizeFunctions) {
Interpreter interpreter;
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 56f3e7604a..1587c3c56f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -127,12 +127,8 @@ public final class OvicClassifierTest {
try {
testResult = classifier.classifyByteBuffer(testImage);
fail();
- } catch (RuntimeException e) {
- assertThat(e)
- .hasMessageThat()
- .contains(
- "Failed to get input dimensions. 0-th input should have 49152 bytes, "
- + "but found 150528 bytes.");
+ } catch (IllegalArgumentException e) {
+ // Success.
}
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 75334cd96e..94a1ec65d6 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -27,10 +27,7 @@ enum DataType {
UINT8(3),
/** 64-bit signed integer. */
- INT64(4),
-
- /** A {@link ByteBuffer}. */
- BYTEBUFFER(999);
+ INT64(4);
private final int value;
@@ -69,8 +66,6 @@ enum DataType {
return 1;
case INT64:
return 8;
- case BYTEBUFFER:
- return 1;
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
@@ -87,8 +82,6 @@ enum DataType {
return "byte";
case INT64:
return "long";
- case BYTEBUFFER:
- return "ByteBuffer";
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 4e22a68bf2..7002f82677 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -165,20 +165,7 @@ public final class Interpreter implements AutoCloseable {
if (wrapper == null) {
throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
}
- Tensor[] tensors = wrapper.run(inputs);
- if (outputs == null || tensors == null || outputs.size() > tensors.length) {
- throw new IllegalArgumentException("Output error: Outputs do not match with model outputs.");
- }
- final int size = tensors.length;
- for (Integer idx : outputs.keySet()) {
- if (idx == null || idx < 0 || idx >= size) {
- throw new IllegalArgumentException(
- String.format(
- "Output error: Invalid index of output %d (should be in range [0, %d))",
- idx, size));
- }
- tensors[idx].copyTo(outputs.get(idx));
- }
+ wrapper.run(inputs, outputs);
}
/**
@@ -251,8 +238,10 @@ public final class Interpreter implements AutoCloseable {
/** Release resources associated with the {@code Interpreter}. */
@Override
public void close() {
- wrapper.close();
- wrapper = null;
+ if (wrapper != null) {
+ wrapper.close();
+ wrapper = null;
+ }
}
@Override
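
The refactoring above moves all output validation and copying into `NativeInterpreterWrapper.run(inputs, outputs)`, and makes `close()` safe to call more than once. A minimal caller-side sketch of the resulting API, assuming the public `runForMultipleInputsOutputs` entry point and the `Interpreter(File)` constructor, with hypothetical input/output shapes:

```java
import java.io.File;
import java.util.HashMap;
import java.util.Map;

import org.tensorflow.lite.Interpreter;

public final class RunExample {
  public static void main(String[] args) {
    // try-with-resources is safe: Interpreter implements AutoCloseable, and
    // close() now tolerates being invoked after the wrapper is already freed.
    try (Interpreter interpreter = new Interpreter(new File("model.tflite"))) {
      float[][] input = new float[1][10];   // hypothetical input shape {1, 10}
      float[][] output = new float[1][3];   // hypothetical output shape {1, 3}
      Map<Integer, Object> outputs = new HashMap<>();
      outputs.put(0, output);               // filled in place by the wrapper
      interpreter.runForMultipleInputsOutputs(new Object[] {input}, outputs);
    }
  }
}
```

With this change, an input array whose shape differs from the model's declared shape triggers an implicit resize inside the wrapper rather than an immediate error.
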
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 80de88b6a1..767a220f8c 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -15,10 +15,10 @@ limitations under the License.
package org.tensorflow.lite;
-import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -40,6 +40,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModel(modelPath, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/**
@@ -72,6 +74,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -85,75 +89,63 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
+ Arrays.fill(inputTensors, null);
+ Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
- Tensor[] run(Object[] inputs) {
+ void run(Object[] inputs, Map<Integer, Object> outputs) {
+ inferenceDurationNanoseconds = -1;
if (inputs == null || inputs.length == 0) {
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
}
- int[] dataTypes = new int[inputs.length];
- Object[] sizes = new Object[inputs.length];
- int[] numsOfBytes = new int[inputs.length];
+ if (outputs == null || outputs.isEmpty()) {
+ throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
+ }
+
+ // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
+ // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
+ // flush all resizes, avoiding redundant allocations.
for (int i = 0; i < inputs.length; ++i) {
- DataType dataType = dataTypeOf(inputs[i]);
- dataTypes[i] = dataType.getNumber();
- if (dataType == DataType.BYTEBUFFER) {
- ByteBuffer buffer = (ByteBuffer) inputs[i];
- if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) {
- throw new IllegalArgumentException(
- "Input error: ByteBuffer should be a direct ByteBuffer that uses "
- + "ByteOrder.nativeOrder().");
- }
- numsOfBytes[i] = buffer.limit();
- sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
- } else if (isNonEmptyArray(inputs[i])) {
- int[] dims = shapeOf(inputs[i]);
- sizes[i] = dims;
- numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
- } else {
- throw new IllegalArgumentException(
- String.format(
- "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.",
- i, inputs.length));
+ Tensor tensor = getInputTensor(i);
+ int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
+ if (newShape != null) {
+ resizeInput(i, newShape);
}
}
- inferenceDurationNanoseconds = -1;
- long[] outputsHandles =
- run(
- interpreterHandle,
- errorHandle,
- sizes,
- dataTypes,
- numsOfBytes,
- inputs,
- this,
- isMemoryAllocated);
- if (outputsHandles == null || outputsHandles.length == 0) {
- throw new IllegalStateException("Internal error: Interpreter has no outputs.");
+
+ if (!isMemoryAllocated) {
+ allocateTensors(interpreterHandle, errorHandle);
+ isMemoryAllocated = true;
+ // Allocation can trigger dynamic resizing of output tensors, so clear the
+ // output tensor cache.
+ Arrays.fill(outputTensors, null);
}
- isMemoryAllocated = true;
- Tensor[] outputs = new Tensor[outputsHandles.length];
- for (int i = 0; i < outputsHandles.length; ++i) {
- outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+
+ for (int i = 0; i < inputs.length; ++i) {
+ getInputTensor(i).setTo(inputs[i]);
+ }
+
+ long inferenceStartNanos = System.nanoTime();
+ run(interpreterHandle, errorHandle);
+ long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+
+ for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
+ getOutputTensor(output.getKey()).copyTo(output.getValue());
}
- return outputs;
+
+ // Only set if the entire operation succeeds.
+ this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
}
- private static native long[] run(
- long interpreterHandle,
- long errorHandle,
- Object[] sizes,
- int[] dtypes,
- int[] numsOfBytes,
- Object[] values,
- NativeInterpreterWrapper wrapper,
- boolean memoryAllocated);
+ private static native boolean run(long interpreterHandle, long errorHandle);
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
+ // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
+ inputTensors[idx] = null;
}
}
@@ -212,78 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- static int numElements(int[] shape) {
- if (shape == null) {
- return 0;
- }
- int n = 1;
- for (int i = 0; i < shape.length; i++) {
- n *= shape[i];
- }
- return n;
- }
-
- static boolean isNonEmptyArray(Object o) {
- return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
- }
-
- /** Returns the type of the data. */
- static DataType dataTypeOf(Object o) {
- if (o != null) {
- Class<?> c = o.getClass();
- while (c.isArray()) {
- c = c.getComponentType();
- }
- if (float.class.equals(c)) {
- return DataType.FLOAT32;
- } else if (int.class.equals(c)) {
- return DataType.INT32;
- } else if (byte.class.equals(c)) {
- return DataType.UINT8;
- } else if (long.class.equals(c)) {
- return DataType.INT64;
- } else if (ByteBuffer.class.isInstance(o)) {
- return DataType.BYTEBUFFER;
- }
- }
- throw new IllegalArgumentException(
- "DataType error: cannot resolve DataType of " + o.getClass().getName());
- }
-
- /** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
- int[] dimensions = new int[size];
- fillShape(o, 0, dimensions);
- return dimensions;
- }
-
- static int numDimensions(Object o) {
- if (o == null || !o.getClass().isArray()) {
- return 0;
- }
- if (Array.getLength(o) == 0) {
- throw new IllegalArgumentException("Array lengths cannot be 0.");
- }
- return 1 + numDimensions(Array.get(o, 0));
- }
-
- static void fillShape(Object o, int dim, int[] shape) {
- if (shape == null || dim == shape.length) {
- return;
- }
- final int len = Array.getLength(o);
- if (shape[dim] == 0) {
- shape[dim] = len;
- } else if (shape[dim] != len) {
- throw new IllegalArgumentException(
- String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
- }
- for (int i = 0; i < len; ++i) {
- fillShape(Array.get(o, i), dim + 1, shape);
- }
- }
-
/**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
@@ -293,40 +213,55 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
/**
- * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+ * Gets the quantization zero point of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- int[] getInputDims(int index) {
- return getInputDims(interpreterHandle, index, -1);
+ int getOutputQuantizationZeroPoint(int index) {
+ return getOutputQuantizationZeroPoint(interpreterHandle, index);
}
/**
- * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
- * input.
+ * Gets the quantization scale of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
-
- /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
- String getOutputDataType(int index) {
- int type = getOutputDataType(interpreterHandle, index);
- return DataType.fromNumber(type).toStringName();
+ float getOutputQuantizationScale(int index) {
+ return getOutputQuantizationScale(interpreterHandle, index);
}
/**
- * Gets the quantization zero point of an output.
+ * Gets the input {@link Tensor} for the provided input index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the input index is invalid.
*/
- int getOutputQuantizationZeroPoint(int index) {
- return getOutputQuantizationZeroPoint(interpreterHandle, index);
+ Tensor getInputTensor(int index) {
+ if (index < 0 || index >= inputTensors.length) {
+ throw new IllegalArgumentException("Invalid input Tensor index: " + index);
+ }
+ Tensor inputTensor = inputTensors[index];
+ if (inputTensor == null) {
+ inputTensor =
+ inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ }
+ return inputTensor;
}
/**
- * Gets the quantization scale of an output.
+ * Gets the output {@link Tensor} for the provided output index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- float getOutputQuantizationScale(int index) {
- return getOutputQuantizationScale(interpreterHandle, index);
+ Tensor getOutputTensor(int index) {
+ if (index < 0 || index >= outputTensors.length) {
+ throw new IllegalArgumentException("Invalid output Tensor index: " + index);
+ }
+ Tensor outputTensor = outputTensors[index];
+ if (outputTensor == null) {
+ outputTensor =
+ outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ }
+ return outputTensor;
}
private static native int getOutputDataType(long interpreterHandle, int outputIdx);
@@ -343,18 +278,30 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private long modelHandle;
- private int inputSize;
-
private long inferenceDurationNanoseconds = -1;
private ByteBuffer modelByteBuffer;
+ // Lazily constructed maps of input and output names to input and output Tensor indexes.
private Map<String, Integer> inputsIndexes;
-
private Map<String, Integer> outputsIndexes;
+ // Lazily constructed and populated arrays of input and output Tensor wrappers.
+ private final Tensor[] inputTensors;
+ private final Tensor[] outputTensors;
+
private boolean isMemoryAllocated = false;
+ private static native void allocateTensors(long interpreterHandle, long errorHandle);
+
+ private static native long getInputTensor(long interpreterHandle, int inputIdx);
+
+ private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+
+ private static native int getInputCount(long interpreterHandle);
+
+ private static native int getOutputCount(long interpreterHandle);
+
private static native String[] getInputNames(long interpreterHandle);
private static native String[] getOutputNames(long interpreterHandle);
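
The wrapper now lazily creates and caches one `Tensor` per input/output index, clearing entries whenever a resize or (re)allocation may have replaced the underlying native tensor. A brief sketch of that contract, assuming package-private access (e.g. from a test in `org.tensorflow.lite`) and the `NativeInterpreterWrapper(String)` constructor used above:

```java
// Hypothetical snippet in the org.tensorflow.lite package.
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper("model.tflite");

Tensor first = wrapper.getInputTensor(0);    // created lazily, then cached
assert first == wrapper.getInputTensor(0);   // same cached wrapper object

wrapper.resizeInput(0, new int[] {1, 8, 8, 3});  // invalidates the cached entry
Tensor resized = wrapper.getInputTensor(0);      // re-created from the handle
assert first != resized;

wrapper.close();
```
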
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index b2a3e04c55..2403570c52 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -15,6 +15,7 @@ limitations under the License.
package org.tensorflow.lite;
+import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
@@ -31,43 +32,179 @@ final class Tensor {
return new Tensor(nativeHandle);
}
+ /** Returns the {@link DataType} of elements stored in the Tensor. */
+ public DataType dataType() {
+ return dtype;
+ }
+
+ /** Returns the size, in bytes, of the tensor data. */
+ public int numBytes() {
+ return numBytes(nativeHandle);
+ }
+
+ /**
+ * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
+ * the Tensor, i.e., the sizes of each dimension.
+ *
+ * @return an array where the i-th element is the size of the i-th dimension of the tensor.
+ */
+ public int[] shape() {
+ return shapeCopy;
+ }
+
+ /**
+ * Copies the contents of the provided {@code src} object to the Tensor.
+ *
+ * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
+ * this tensor, or a {@link ByteBuffer} of compatible primitive type with a matching flat size.
+ *
+ * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
+ * with the tensor (for example, mismatched data types or shapes).
+ */
+ void setTo(Object src) {
+ throwExceptionIfTypeIsIncompatible(src);
+ if (isByteBuffer(src)) {
+ ByteBuffer srcBuffer = (ByteBuffer) src;
+ // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller
+ // retains ownership of the source buffer until inference has completed.
+ if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
+ writeDirectBuffer(nativeHandle, srcBuffer);
+ } else {
+ buffer().put(srcBuffer);
+ }
+ return;
+ }
+ writeMultiDimensionalArray(nativeHandle, src);
+ }
+
/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
* mismatched data types or shapes).
- * @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
- * data in this tensor.
*/
- <T> T copyTo(T dst) {
+ Object copyTo(Object dst) {
+ throwExceptionIfTypeIsIncompatible(dst);
if (dst instanceof ByteBuffer) {
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
dstByteBuffer.put(buffer());
return dst;
}
- if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
+
+ /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
+ // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
+ int[] getInputShapeIfDifferent(Object input) {
+ // Implicit resizes based on ByteBuffer capacity are not supported, so short-circuit that path.
+ // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
+ if (isByteBuffer(input)) {
+ return null;
+ }
+ int[] inputShape = shapeOf(input);
+ if (Arrays.equals(shapeCopy, inputShape)) {
+ return null;
+ }
+ return inputShape;
+ }
+
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType error: cannot resolve DataType of " + o.getClass().getName());
+ }
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("Array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
+
+ private void throwExceptionIfTypeIsIncompatible(Object o) {
+ if (isByteBuffer(o)) {
+ ByteBuffer oBuffer = (ByteBuffer) o;
+ if (oBuffer.capacity() != numBytes()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot convert between a TensorFlowLite buffer with %d bytes and a "
+ + "ByteBuffer with %d bytes.",
+ numBytes(), oBuffer.capacity()));
+ }
+ return;
+ }
+ DataType oType = dataTypeOf(o);
+ if (oType != dtype) {
throw new IllegalArgumentException(
String.format(
- "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java "
- + "object of type %s (which is compatible with the TensorFlowLite type %s)",
- dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
+ + "object of type %s (which is compatible with the TensorFlowLite type %s).",
+ dtype, o.getClass().getName(), oType));
}
- int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
- if (!Arrays.equals(dstShape, shapeCopy)) {
+
+ int[] oShape = shapeOf(o);
+ if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
- "Output error: Shape of output target %s does not match with the shape of the "
- + "Tensor %s.",
- Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
+ + "with shape %s.",
+ Arrays.toString(shapeCopy), Arrays.toString(oShape)));
}
- readMultiDimensionalArray(nativeHandle, dst);
- return dst;
}
- final long nativeHandle;
- final DataType dtype;
- final int[] shapeCopy;
+ private static boolean isByteBuffer(Object o) {
+ return o instanceof ByteBuffer;
+ }
+
+ private final long nativeHandle;
+ private final DataType dtype;
+ private final int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
@@ -81,11 +218,17 @@ final class Tensor {
private static native ByteBuffer buffer(long handle);
+ private static native void writeDirectBuffer(long handle, ByteBuffer src);
+
private static native int dtype(long handle);
private static native int[] shape(long handle);
- private static native void readMultiDimensionalArray(long handle, Object value);
+ private static native int numBytes(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object dst);
+
+ private static native void writeMultiDimensionalArray(long handle, Object src);
static {
TensorFlowLite.init();
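
The shape/type helpers relocated here from `NativeInterpreterWrapper` keep their original semantics. A small sketch of what they compute, assuming package-private access from code in `org.tensorflow.lite`:

```java
// Hypothetical snippet in the org.tensorflow.lite package.
float[][][] data = new float[2][3][4];

int[] shape = Tensor.shapeOf(data);        // -> {2, 3, 4}
DataType type = Tensor.dataTypeOf(data);   // -> DataType.FLOAT32
int dims = Tensor.numDimensions(data);     // -> 3

// Ragged arrays are rejected: fillShape throws IllegalArgumentException when
// sibling sub-arrays disagree on length, e.g. for:
float[][] ragged = {new float[2], new float[3]};
// Tensor.shapeOf(ragged);  // would throw "Mismatched lengths (2 and 3) ..."
```
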
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
index 4399ed2025..4b4e1c21d8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/BUILD
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0
cc_library(
name = "native_framework_only",
srcs = [
- "duration_utils_jni.cc",
"exception_jni.cc",
"nativeinterpreterwrapper_jni.cc",
"tensor_jni.cc",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 31f7b58fbc..e2c1edd9af 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -16,9 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
namespace {
-const int kByteBufferValue = 999;
-const int kBufferSize = 256;
-
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
@@ -62,22 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
-
-TfLiteType resolveDataType(jint data_type) {
- switch (data_type) {
- case 1:
- return kTfLiteFloat32;
- case 2:
- return kTfLiteInt32;
- case 3:
- return kTfLiteUInt8;
- case 4:
- return kTfLiteInt64;
- default:
- return kTfLiteNoType;
- }
-}
int getDataType(TfLiteType data_type) {
switch (data_type) {
@@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) {
}
}
-TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- const int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values,
- jobjectArray sizes) {
- if (input_size != interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Expected num of inputs is %d but got %d",
- interpreter->inputs().size(), input_size);
- return kTfLiteError;
- }
- if (input_size != env->GetArrayLength(data_types) ||
- input_size != env->GetArrayLength(nums_of_bytes) ||
- input_size != env->GetArrayLength(values)) {
- throwException(env, kIllegalArgumentException,
- "Internal error: Arrays in arguments should be of the same "
- "length, but got %d sizes, %d data_types, %d nums_of_bytes, "
- "and %d values",
- input_size, env->GetArrayLength(data_types),
- env->GetArrayLength(nums_of_bytes),
- env->GetArrayLength(values));
- return kTfLiteError;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- int num_dims = static_cast<int>(env->GetArrayLength(dims));
- if (target->dims->size != num_dims) {
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input should have %d dimensions, but "
- "found %d dimensions",
- i, target->dims->size, num_dims);
- return kTfLiteError;
- }
- jint* ptr = env->GetIntArrayElements(dims, nullptr);
- for (int j = 1; j < num_dims; ++j) {
- if (target->dims->data[j] != ptr[j]) {
- std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
- std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
- printDims(expected_dims.get(), kBufferSize, target->dims->data,
- num_dims);
- printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input dimension should be [%s], but "
- "found [%s]",
- i, expected_dims.get(), obtained_dims.get());
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- return kTfLiteError;
- }
- }
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
// Checks whether there is any difference between dimensions of a tensor and a
// given dimensions. Returns true if there is difference, else false.
bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
@@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
return false;
}
-bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- if (interpreter->inputs().size() != input_size) {
- return false;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteTensor* target = interpreter->tensor(input_idx);
- if (areDimsDifferent(env, target, dims)) return false;
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return false;
- }
- return true;
-}
-
-TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteStatus status = interpreter->ResizeInputTensor(
- input_idx, convertJIntArrayToVector(env, dims));
- if (status != kTfLiteOk) {
- return status;
- }
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values) {
- jint* data_type = env->GetIntArrayElements(data_types, nullptr);
- jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jobject value = env->GetObjectArrayElement(values, i);
- bool is_byte_buffer = isByteBuffer(data_type[i]);
- if (is_byte_buffer) {
- writeByteBuffer(env, value, &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- } else {
- TfLiteType type = resolveDataType(data_type[i]);
- if (type != target->type) {
- throwException(env, kIllegalArgumentException,
- "Input error: DataType (%d) of input data does not "
- "match with the DataType (%d) of model inputs.",
- type, target->type);
- return kTfLiteError;
- }
- writeMultiDimensionalArray(env, value, target->type, target->dims->size,
- &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- }
- env->DeleteLocalRef(value);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
- env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
- return kTfLiteOk;
-}
-
// TODO(yichengfan): evaluate the benefit to use tflite verifier.
bool VerifyModel(const void* buf, size_t len) {
flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
@@ -287,6 +142,63 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
return names;
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Internal error: Cannot allocate memory for the interpreter:"
+ " %s",
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->inputs()[index]));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->outputs()[index]));
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->inputs().size());
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->outputs().size());
+}
+
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
jclass clazz,
@@ -434,114 +346,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
}
// Sets inputs, runs inference, and returns outputs as long handles.
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated) {
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
- if (interpreter == nullptr) return nullptr;
+ if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
- if (error_reporter == nullptr) return nullptr;
- const int input_size = env->GetArrayLength(sizes);
- // validates inputs
- TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
- nums_of_bytes, values, sizes);
- if (status != kTfLiteOk) return nullptr;
- if (!memory_allocated ||
- !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) {
- // resizes inputs
- status = resizeInputs(env, interpreter, input_size, sizes);
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not resize the input: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- // allocates memory
- status = interpreter->AllocateTensors();
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not allocate memory for the given "
- "inputs: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- }
- // sets inputs
- status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
- values);
- if (status != kTfLiteOk) return nullptr;
- timespec beforeInference = ::tflite::getCurrentTime();
- // runs inference
+ if (error_reporter == nullptr) return;
+
if (interpreter->Invoke() != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Internal error: Failed to run on the given Interpreter: %s",
error_reporter->CachedErrorMessage());
- return nullptr;
- }
- timespec afterInference = ::tflite::getCurrentTime();
- jclass wrapper_clazz = env->GetObjectClass(wrapper);
- jfieldID fid =
- env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J");
- if (env->ExceptionCheck()) {
- env->ExceptionClear();
- } else if (fid != nullptr) {
- env->SetLongField(
- wrapper, fid,
- ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference));
- }
- // returns outputs
- const std::vector<int>& results = interpreter->outputs();
- if (results.empty()) {
- throwException(
- env, kIllegalArgumentException,
- "Internal error: The Interpreter does not have any outputs.");
- return nullptr;
- }
- jlongArray outputs = env->NewLongArray(results.size());
- size_t size = results.size();
- for (int i = 0; i < size; ++i) {
- TfLiteTensor* source = interpreter->tensor(results[i]);
- jlong output = reinterpret_cast<jlong>(source);
- env->SetLongArrayRegion(outputs, i, 1, &output);
- }
- return outputs;
-}
-
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
- if (interpreter == nullptr) return nullptr;
- const int idx = static_cast<int>(input_idx);
- if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Out of range: Failed to get %d-th input out of"
- " %d inputs",
- input_idx, interpreter->inputs().size());
- return nullptr;
- }
- TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
- int size = target->dims->size;
- if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid.
- int expected_num_bytes = elementByteSize(target->type);
- for (int i = 0; i < size; ++i) {
- expected_num_bytes *= target->dims->data[i];
- }
- if (num_bytes != expected_num_bytes) {
- throwException(env, kIllegalArgumentException,
- "Input error: Failed to get input dimensions. %d-th input "
- "should have %d bytes, but found %d bytes.",
- idx, expected_num_bytes, num_bytes);
- return nullptr;
- }
+ return;
}
- jintArray outputs = env->NewIntArray(size);
- env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
- return outputs;
}
JNIEXPORT jint JNICALL
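
The net effect of the hunks above: the native run() shrinks to a bare Invoke(), while resizing, allocation, and input/output marshalling move to the tensor handles exposed by allocateTensors/getInputTensor/getOutputTensor. A minimal sketch of the equivalent flow against the C++ tflite::Interpreter API (the helper name and the single-input/single-output assumption are illustrative, not part of the patch):

```cpp
#include <cstring>

#include "tensorflow/contrib/lite/interpreter.h"

// Hypothetical helper mirroring what the Java wrapper now does around the
// void run() native: allocate, write the input, invoke, read the output back.
TfLiteStatus RunOnce(tflite::Interpreter* interpreter, const float* input,
                     size_t input_bytes, float* output, size_t output_bytes) {
  // allocateTensors() is now a separate, explicit JNI entry point.
  if (interpreter->AllocateTensors() != kTfLiteOk) return kTfLiteError;

  // Inputs are written through the handle returned by getInputTensor().
  TfLiteTensor* in = interpreter->tensor(interpreter->inputs()[0]);
  if (in->bytes != input_bytes) return kTfLiteError;
  std::memcpy(in->data.raw, input, input_bytes);

  // run() itself no longer resizes or copies; it only invokes the graph.
  if (interpreter->Invoke() != kTfLiteOk) return kTfLiteError;

  // Outputs are read back through the handle from getOutputTensor().
  const TfLiteTensor* out = interpreter->tensor(interpreter->outputs()[0]);
  if (out->bytes != output_bytes) return kTfLiteError;
  std::memcpy(output, out->data.raw, output_bytes);
  return kTfLiteOk;
}
```
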
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 128ece4981..618fba480e 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -29,9 +29,6 @@ limitations under the License.
namespace tflite {
// This is to be provided at link-time by a library.
extern std::unique_ptr<OpResolver> CreateOpResolver();
-extern timespec getCurrentTime();
-extern jlong timespec_diff_nanoseconds(struct timespec* start,
- struct timespec* stop);
} // namespace tflite
#ifdef __cplusplus
@@ -40,6 +37,57 @@ extern "C" {
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: allocateTensors
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
* Signature: (J)[Ljava/lang/Object;
*/
@@ -118,28 +166,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature:
- * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J
- */
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated);
-
-/*
- * Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature: (JII)[I
- *
- * Gets input dimensions. If num_bytes is non-negative, it will check whether
- * num_bytes matches num of bytes required by the input, and return null and
- * throw IllegalArgumentException if not.
+ * Method: run
+ * Signature: (JJ)V
*/
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
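
To decode the Signature comments above: J is jlong, I is jint, and V is void, so (JJ)V takes two long handles and returns nothing, while (JI)J takes a handle plus an index and returns a tensor handle. TFLite binds these functions through their exported Java_... symbol names; the RegisterNatives table below is purely illustrative, shown only to make the name-to-signature mapping explicit:

```cpp
#include <jni.h>

// Illustrative only: the library does not actually register methods this way.
static const JNINativeMethod kWrapperMethods[] = {
    {"allocateTensors", "(JJ)V",
     reinterpret_cast<void*>(
         Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors)},
    {"getInputTensor", "(JI)J",
     reinterpret_cast<void*>(
         Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor)},
    {"getInputCount", "(J)I",
     reinterpret_cast<void*>(
         Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount)},
    {"run", "(JJ)V",
     reinterpret_cast<void*>(
         Java_org_tensorflow_lite_NativeInterpreterWrapper_run)},
};
```
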
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 08b4d04280..7ff96a3172 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
return reinterpret_cast<TfLiteTensor*>(handle);
}
+size_t elementByteSize(TfLiteType data_type) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+ static_assert(sizeof(jfloat) == 4,
+ "Interal error: Java float not compatible with "
+ "kTfLiteFloat");
+ return 4;
+ case kTfLiteInt32:
+ static_assert(sizeof(jint) == 4,
+ "Interal error: Java int not compatible with kTfLiteInt");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Interal error: Java byte not compatible with "
+ "kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Interal error: Java long not compatible with "
+ "kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
+
size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
void* dst, size_t dst_size) {
jarray array = static_cast<jarray>(object);
@@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
}
}
-} // namespace
-
-size_t elementByteSize(TfLiteType data_type) {
- // The code in this file makes the assumption that the
- // TensorFlow TF_DataTypes and the Java primitive types
- // have the same byte sizes. Validate that:
- switch (data_type) {
- case kTfLiteFloat32:
- static_assert(sizeof(jfloat) == 4,
- "Interal error: Java float not compatible with "
- "kTfLiteFloat");
- return 4;
- case kTfLiteInt32:
- static_assert(sizeof(jint) == 4,
- "Interal error: Java int not compatible with kTfLiteInt");
- return 4;
- case kTfLiteUInt8:
- static_assert(sizeof(jbyte) == 1,
- "Interal error: Java byte not compatible with "
- "kTfLiteUInt8");
- return 1;
- case kTfLiteInt64:
- static_assert(sizeof(jlong) == 8,
- "Interal error: Java long not compatible with "
- "kTfLiteInt64");
- return 8;
- default:
- return 0;
- }
-}
-
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
- char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
- if (!buf) {
- throwException(env, kIllegalArgumentException,
- "Input ByteBuffer is not a direct buffer");
- return 0;
- }
- *dst = buf;
- return dst_size;
-}
-
size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
int dims_left, char** dst, int dst_size) {
if (dims_left <= 1) {
@@ -203,16 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
+} // namespace
+
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
TfLiteTensor* tensor = convertLongToTensor(env, handle);
if (tensor == nullptr) return nullptr;
-
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Tensor hasn't been allocated.");
+ return nullptr;
+ }
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
static_cast<jlong>(tensor->bytes));
}
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+
+ char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+ if (!src_data_raw) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return;
+ }
+
+ tensor->data.raw = src_data_raw;
+}
+
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
@@ -230,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
num_dims, static_cast<jarray>(value));
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Target Tensor hasn't been allocated.");
+ return;
+ }
+ if (tensor->dims->size == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Cannot copy empty/scalar Tensors.");
+ return;
+ }
+ writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size,
+ &tensor->data.raw, tensor->bytes);
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
@@ -247,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data);
return result;
}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->bytes);
+}
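
Two points worth making explicit about the tensor_jni.cc changes: first, writeDirectBuffer aliases tensor->data.raw to the Java ByteBuffer's storage instead of copying, so the buffer must remain alive (and direct) for as long as the tensor is used; second, elementByteSize underpins the byte-count checks behind the new "Cannot convert between a TensorFlowLite buffer with N bytes..." errors exercised in the tests further down. A sketch of such a check (the helper name is hypothetical; elementByteSize is the file-local function above):

```cpp
// Hypothetical validation helper: recompute the tensor's byte size from its
// element type and dimensions, and compare it to a caller-supplied buffer.
inline bool BufferMatchesTensor(const TfLiteTensor* tensor,
                                size_t buffer_bytes) {
  size_t expected = elementByteSize(tensor->type);  // 0 for unknown types.
  for (int i = 0; i < tensor->dims->size; ++i) {
    expected *= static_cast<size_t>(tensor->dims->data[i]);
  }
  // For a fully allocated tensor this product equals tensor->bytes.
  return expected != 0 && expected == buffer_bytes;
}
```
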
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 9ba95d9ac4..06e2546af8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -34,6 +34,14 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: writeDirectBuffer
+ * Signature: (JLjava/nio/ByteBuffer;)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: dtype
* Signature: (J)I
*/
@@ -52,6 +60,15 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: numBytes
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: readMultiDimensionalArray
* Signature: (JLjava/lang/Object;)
*/
@@ -59,23 +76,18 @@ JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
- jobject value);
+ jobject dst);
/*
- * Finds the size of each data type.
- */
-size_t elementByteSize(TfLiteType data_type);
-
-/*
- * Writes data of a ByteBuffer into dest.
- */
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
-
-/*
- * Writes a multi-dimensional array into dest.
+ * Class: org_tensorflow_lite_Tensor
+ * Method: writeMultiDimensionalArray
+ * Signature: (JLjava/lang/Object;)V
*/
-size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
- int dims_left, char** dst, int dst_size);
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index e6deadffe2..d66a73db94 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -221,7 +221,9 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
}
interpreter.close();
}
@@ -241,8 +243,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
interpreter.close();
@@ -329,4 +331,11 @@ public final class InterpreterTest {
interpreter.close();
fileChannel.close();
}
+
+ @Test
+ public void testRedundantClose() throws Exception {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ interpreter.close();
+ interpreter.close();
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 029e5853e2..9c4a5acd79 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -20,6 +20,8 @@ import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -101,10 +103,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -119,11 +121,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs).hasLength(1);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
- outputs[0].copyTo(parsedOutput);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutput);
+ wrapper.run(inputs, outputs);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
@@ -140,17 +142,16 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
outputOneD = parsedOutputs[0][0][0];
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
wrapper.close();
@@ -164,10 +165,10 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
int[][][][] parsedOutputs = new int[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
int[] outputOneD = parsedOutputs[0][0][0];
int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
assertThat(outputOneD).isEqualTo(expected);
@@ -182,10 +183,10 @@ public final class NativeInterpreterWrapperTest {
long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
long[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
long[][][][] parsedOutputs = new long[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
@@ -203,10 +204,10 @@ public final class NativeInterpreterWrapperTest {
Object[] inputs = {fourD};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
@@ -229,13 +230,14 @@ public final class NativeInterpreterWrapperTest {
}
}
}
+ bbuf.rewind();
Object[] inputs = {bbuf};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
@@ -261,21 +263,22 @@ public final class NativeInterpreterWrapperTest {
}
}
Object[] inputs = {bbuf};
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ "Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
+ + "ByteBuffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
- float[][][][] parsedOutputs = new float[4][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -288,14 +291,18 @@ public final class NativeInterpreterWrapperTest {
ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
bbuf.order(ByteOrder.nativeOrder());
Object[] inputs = {bbuf};
+ Map<Integer, Object> outputs = new HashMap<>();
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ outputs.put(0, parsedOutput);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ "Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
+ + "ByteBuffer with 336 bytes.");
}
wrapper.close();
}
@@ -308,14 +315,18 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
wrapper.close();
}
@@ -329,8 +340,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
@@ -342,7 +356,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
try {
Object[] inputs = {};
- wrapper.run(inputs);
+ wrapper.run(inputs, null);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Inputs should not be null or empty.");
@@ -358,11 +372,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD, fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1");
}
wrapper.close();
}
@@ -374,13 +391,18 @@ public final class NativeInterpreterWrapperTest {
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Object[] inputs = {threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@@ -393,92 +415,23 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@Test
- public void testNumElements() {
- int[] shape = {2, 3, 4};
- int num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(24);
- shape = null;
- num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(0);
- }
-
- @Test
- public void testIsNonEmtpyArray() {
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
- int[] emptyArray = {};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
- int[] validArray = {9, 5, 2, 1};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
- }
-
- @Test
- public void testDataTypeOf() {
- float[] testEmtpyArray = {};
- DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[] testFloatArray = {0.783f, 0.251f};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- try {
- double[] testDoubleArray = {0.783, 0.251};
- NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
- }
- try {
- Float[] testBoxedArray = {0.783f, 0.251f};
- NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
- }
- }
-
- @Test
- public void testNumDimensions() {
- int scalar = 1;
- assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
- int[][] array = {{2, 4}, {1, 9}};
- assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
- try {
- int[] emptyArray = {};
- NativeInterpreterWrapper.numDimensions(emptyArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
- }
- }
-
- @Test
- public void testFillShape() {
- int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = NativeInterpreterWrapper.numDimensions(array);
- int[] shape = new int[num];
- NativeInterpreterWrapper.fillShape(array, 0, shape);
- assertThat(num).isEqualTo(3);
- assertThat(shape[0]).isEqualTo(2);
- assertThat(shape[1]).isEqualTo(3);
- assertThat(shape[2]).isEqualTo(1);
- }
-
- @Test
public void testGetInferenceLatency() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -486,8 +439,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
wrapper.close();
}
@@ -507,13 +462,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e)
- .hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ // Expected.
}
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
@@ -523,41 +479,7 @@ public final class NativeInterpreterWrapperTest {
public void testGetInputDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
int[] expectedDims = {1, 8, 8, 3};
- assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
- wrapper.close();
- }
-
- @Test
- public void testGetInputDimsOutOfRange() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- try {
- wrapper.getInputDims(-1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- try {
- wrapper.getInputDims(1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- wrapper.close();
- }
-
- @Test
- public void testGetOutputDataType() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("float");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("long");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("int");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("byte");
+ assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
wrapper.close();
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index dd9d37eeda..71ef044943 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -18,9 +18,10 @@ package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
-import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -35,7 +36,7 @@ public final class TensorTest {
"tensorflow/contrib/lite/java/src/testdata/add.bin";
private NativeInterpreterWrapper wrapper;
- private long nativeHandle;
+ private Tensor tensor;
@Before
public void setUp() {
@@ -45,8 +46,10 @@ public final class TensorTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- nativeHandle = outputs[0].nativeHandle;
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, new float[2][8][8][3]);
+ wrapper.run(inputs, outputs);
+ tensor = wrapper.getOutputTensor(0);
}
@After
@@ -55,17 +58,16 @@ public final class TensorTest {
}
@Test
- public void testFromHandle() throws Exception {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
+ public void testBasic() throws Exception {
assertThat(tensor).isNotNull();
int[] expectedShape = {2, 8, 8, 3};
- assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
- assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.shape()).isEqualTo(expectedShape);
+ assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
}
@Test
public void testCopyTo() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[2][8][8][3];
tensor.copyTo(parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
@@ -75,7 +77,6 @@ public final class TensorTest {
@Test
public void testCopyToByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
@@ -89,19 +90,17 @@ public final class TensorTest {
@Test
public void testCopyToInvalidByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
try {
tensor.copyTo(parsedOutput);
fail();
- } catch (BufferOverflowException e) {
+ } catch (IllegalArgumentException e) {
// Expected.
}
}
@Test
public void testCopyToWrongType() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
int[][][][] parsedOutputs = new int[2][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -110,15 +109,13 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
- + "type INT32)");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
}
@Test
public void testCopyToWrongShape() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[1][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -127,8 +124,104 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Shape of output target [1, 8, 8, 3] does not match "
- + "with the shape of the Tensor [2, 8, 8, 3].");
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ + "and a Java object with shape [1, 8, 8, 3].");
+ }
+ }
+
+ @Test
+ public void testSetTo() {
+ float[][][][] input = new float[2][8][8][3];
+ float[][][][] output = new float[2][8][8][3];
+ ByteBuffer inputByteBuffer =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+
+ input[0][0][0][0] = 2.0f;
+ tensor.setTo(input);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(2.0f);
+
+ inputByteBuffer.putFloat(0, 3.0f);
+ tensor.setTo(inputByteBuffer);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(3.0f);
+ }
+
+ @Test
+ public void testSetToInvalidByteBuffer() {
+ ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ try {
+ tensor.setTo(input);
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Success.
+ }
+ }
+
+ @Test
+ public void testGetInputShapeIfDifferent() {
+ ByteBuffer byteBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ assertThat(tensor.getInputShapeIfDifferent(byteBufferInput)).isNull();
+
+ float[][][][] sameShapeInput = new float[2][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull();
+
+ float[][][][] differentShapeInput = new float[1][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
+ .isEqualTo(new int[] {1, 8, 8, 3});
+ }
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmptyArray = {};
+ DataType dataType = Tensor.dataTypeOf(testEmptyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = Tensor.dataTypeOf(testMultiDimArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ Tensor.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
+ }
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ Tensor.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
}
}
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ Tensor.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = Tensor.numDimensions(array);
+ int[] shape = new int[num];
+ Tensor.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 3aef0c3bb6..c23521c077 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -58,7 +58,7 @@ public class TestHelper {
*/
public static int[] getInputDims(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getInputDims(index);
+ return interpreter.wrapper.getInputTensor(index).shape();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get input dimensions.");
@@ -77,7 +77,7 @@ public class TestHelper {
*/
public static String getOutputDataType(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getOutputDataType(index);
+ return interpreter.wrapper.getOutputTensor(index).dataType().toStringName();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get output data type.");
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 27b8a16e15..33594c138b 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -46,11 +46,17 @@ cc_library(
hdrs = [
"eigen_support.h",
],
- copts = tflite_copts(),
+ copts = tflite_copts() + [
+ "-Wno-error=reorder",
+ ] + select({
+ "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+ "//conditions:default": [
+ ],
+ }),
deps = [
":op_macros",
"//tensorflow/contrib/lite:context",
- "//third_party/eigen3",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -130,7 +136,7 @@ cc_library(
srcs = [
"activations.cc",
"add.cc",
- "arg_max.cc",
+ "arg_min_max.cc",
"audio_spectrogram.cc",
"basic_rnn.cc",
"batch_to_space_nd.cc",
@@ -149,6 +155,7 @@ cc_library(
"embedding_lookup_sparse.cc",
"exp.cc",
"expand_dims.cc",
+ "fake_quant.cc",
"floor.cc",
"fully_connected.cc",
"gather.cc",
@@ -290,9 +297,9 @@ tf_cc_test(
)
tf_cc_test(
- name = "arg_max_test",
+ name = "arg_min_max_test",
size = "small",
- srcs = ["arg_max_test.cc"],
+ srcs = ["arg_min_max_test.cc"],
tags = [
"tflite_not_portable_ios",
],
@@ -558,6 +565,19 @@ tf_cc_test(
)
tf_cc_test(
+ name = "fake_quant_test",
+ size = "small",
+ srcs = ["fake_quant_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "maximum_minimum_test",
size = "small",
srcs = ["maximum_minimum_test.cc"],
@@ -964,7 +984,6 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
- "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 26f57e8896..4f30d09030 100644
--- a/tensorflow/contrib/lite/kernels/arg_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -23,7 +23,7 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace builtin {
-namespace arg_max {
+namespace arg_min_max {
constexpr int kInputTensor = 0;
constexpr int kAxis = 1;
@@ -80,30 +80,39 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
+template <typename T>
+std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
+ if (is_arg_max) {
+ return std::greater<T>();
+ } else {
+ return std::less<T>();
+ }
+}
+
// The current impl actually ignores the axis argument.
// Only determine the index of the maximum value in the last dimension.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMax(GetTensorData<axis_type>(axis), \
- GetTensorData<data_type>(input), GetTensorDims(input), \
- GetTensorData<output_type>(output), \
- GetTensorDims(output))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
+ GetTensorDims(input), GetTensorData<output_type>(output), \
+ GetTensorDims(output), GetCompareFunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
break;
default:
return kTfLiteError;
@@ -112,13 +121,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
break;
default:
return kTfLiteError;
@@ -132,13 +141,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
break;
default:
return kTfLiteError;
@@ -147,13 +156,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
break;
default:
return kTfLiteError;
@@ -163,16 +172,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
}
-#undef TF_LITE_ARG_MAX
+#undef TF_LITE_ARG_MIN_MAX
return kTfLiteOk;
}
-} // namespace arg_max
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, false);
+}
+
+TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, true);
+}
+
+} // namespace arg_min_max
TfLiteRegistration* Register_ARG_MAX() {
- static TfLiteRegistration r = {nullptr, nullptr, arg_max::Prepare,
- arg_max::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMaxEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_ARG_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMinEval};
return &r;
}
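
The refactoring collapses ARG_MAX and ARG_MIN into one Eval parameterized by a comparator: std::greater<T>() selects the maximum, std::less<T>() the minimum. A self-contained sketch of the reduction that the optimized_ops::ArgMinMax call performs (assuming, per the kernel comment, that only the innermost dimension is reduced):

```cpp
#include <functional>

// Reduce the innermost dimension of a flattened [outer_size, depth] view,
// writing the winning index per row. cmp(a, b) == true means "a beats b";
// ties keep the earliest index because the comparison is strict.
template <typename T, typename OutT>
void ArgMinMaxLastDim(const T* input, int outer_size, int depth, OutT* output,
                      const std::function<bool(T, T)>& cmp) {
  for (int i = 0; i < outer_size; ++i) {
    const T* row = input + i * depth;
    int best = 0;
    for (int d = 1; d < depth; ++d) {
      if (cmp(row[d], row[best])) best = d;
    }
    output[i] = static_cast<OutT>(best);
  }
}
```

With std::less<float>() over {0.1, 0.9, 0.7, 0.3} this yields index 0, matching the GetMinArgFloat test below.
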
diff --git a/tensorflow/contrib/lite/kernels/arg_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
index 31b15fe19a..90e5fdc532 100644
--- a/tensorflow/contrib/lite/kernels/arg_max_test.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
@@ -24,16 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
-class ArgMaxOpModel : public SingleOpModel {
+class ArgBaseOpModel : public SingleOpModel {
public:
- ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
- TensorType output_type, TensorType index_output_type) {
+ ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type) {
input_ = AddInput(input_type);
axis_ = AddInput(TensorType_INT32);
output_ = AddOutput(output_type);
- SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(builder_, index_output_type).Union());
- BuildInterpreter({input_shape, {1, 1, 1, 1}});
}
int input() { return input_; }
@@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel {
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int axis_;
int output_;
};
+template <typename T>
+class ArgMaxOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
+template <typename T>
+class ArgMinOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
@@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
}
+TEST(ArgMinOpTest, GetMinArgFloat) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
+ TensorType_INT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgInt) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgMulDimensions) {
+ ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgOutput64) {
+ ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+ TensorType_INT64);
+ model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 3425288f02..14a19aeef3 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -276,27 +276,33 @@ TfLiteStatus CheckLstmTensorDimensions(
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor,
- kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor,
- kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor,
- kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor,
- kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor,
- kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor,
- kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
- kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
- kFwProjectionBiasTensor);
-
- CheckLstmTensorDimensions(
- context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor,
- kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor,
- kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor,
- kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor,
- kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor,
- kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor,
- kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
- kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
- kBwProjectionBiasTensor);
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
+ kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
+ kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
+ kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
+ kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
+ kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
+ kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
+ kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
+ kFwProjectionBiasTensor));
+
+ TF_LITE_ENSURE_OK(
+ context,
+ CheckLstmTensorDimensions(
+ context, node, n_input, n_output, n_cell,
+ kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
+ kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
+ kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
+ kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
+ kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
+ kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
+ kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
+ kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
+ kBwProjectionBiasTensor));
// Check if Forward and Backward tensors match along required dimensions.
return kTfLiteOk;
@@ -334,7 +340,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
   // Check that the input tensor dimensions match each other.
- CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell);
+ TF_LITE_ENSURE_OK(
+ context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
+ n_fw_cell));
// Get the pointer to output, state and scratch buffer tensors.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
@@ -404,7 +412,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
   // Check that the input tensor dimensions match each other.
- CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
+ TF_LITE_ENSURE_OK(
+ context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
+ n_bw_cell));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
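The point of wrapping these calls is that CheckLstmTensorDimensions and CheckInputTensorDimensions return a TfLiteStatus that was previously computed and then discarded, so a dimension mismatch never aborted Prepare(). TF_LITE_ENSURE_OK propagates a failing status to the caller; its behavior is roughly the following sketch (illustrative, with a renamed macro -- the real definition lives in context.h):

// Evaluate the status expression once; if it is not kTfLiteOk, early-return
// it from the enclosing function instead of silently continuing.
#define TF_LITE_ENSURE_OK_SKETCH(context, expr) \
  do {                                          \
    const TfLiteStatus s = (expr);              \
    if (s != kTfLiteOk) return s;               \
  } while (false)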
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 0321b2e2a0..a4fe9e5550 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -418,6 +418,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
+ *eigen_support::GetThreadPoolDevice(context),
GetTensorData<float>(input), GetTensorDims(input), filter_data,
GetTensorDims(filter), GetTensorData<float>(bias),
GetTensorDims(bias), params->stride_width, params->stride_height,
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 94927cb53d..4f0d020793 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,14 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace eigen_support {
namespace {
+// Each TfLiteContext gets a single threadpool that is shared by all of its
+// convolution operations. Inferences started from different threads may
+// therefore block each other, but since the operations would be competing
+// for the same CPU cores anyway, this shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ // Takes ownership of 'pool'
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ std::unique_ptr<Eigen::ThreadPool> pool_;
+};
+
struct RefCountedEigenContext : public TfLiteExternalContext {
+ std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
int num_references = 0;
};
@@ -30,8 +54,26 @@ RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
context->GetExternalContext(context, kTfLiteEigenContext));
}
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+ int num_threads = 4;
+ if (context->recommended_num_threads != -1) {
+ num_threads = context->recommended_num_threads;
+ }
+ ptr->device.reset(); // destroy before we invalidate the thread pool
+ ptr->thread_pool_wrapper.reset(
+ new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+ ptr->device.reset(
+ new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
TfLiteStatus Refresh(TfLiteContext* context) {
Eigen::setNbThreads(context->recommended_num_threads);
+
+ auto* ptr = GetEigenContext(context);
+ if (ptr != nullptr) {
+ InitDevice(context, ptr);
+ }
+
return kTfLiteOk;
}
@@ -47,6 +89,7 @@ void IncrementUsageCounter(TfLiteContext* context) {
ptr->type = kTfLiteEigenContext;
ptr->Refresh = Refresh;
ptr->num_references = 0;
+ InitDevice(context, ptr);
context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
@@ -65,5 +108,14 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
}
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+ auto* ptr = GetEigenContext(context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetThreadPoolDevice() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->device.get();
+}
+
} // namespace eigen_support
} // namespace tflite
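Kernels consume the shared device through the usual init/free pairing. A minimal sketch of the lifecycle, with hypothetical hook names (conv.cc wires this up through its own Init/Free/Eval functions):

// IncrementUsageCounter must run before GetThreadPoolDevice, or the fatal
// check above fires; DecrementUsageCounter releases the shared context.
void* KernelInit(TfLiteContext* context, const char* /*buffer*/,
                 size_t /*length*/) {
  eigen_support::IncrementUsageCounter(context);
  return nullptr;
}

void KernelFree(TfLiteContext* context, void* /*buffer*/) {
  eigen_support::DecrementUsageCounter(context);
}

TfLiteStatus KernelEval(TfLiteContext* context, TfLiteNode* node) {
  const Eigen::ThreadPoolDevice* device =
      eigen_support::GetThreadPoolDevice(context);
  // ... pass *device to multithreaded_ops::Conv(...) as conv.cc does ...
  return kTfLiteOk;
}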
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index d47e691123..ec77856b10 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -17,6 +17,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
+namespace EigenForTFLite {
+class ThreadPoolDevice;
+}
+
namespace tflite {
namespace eigen_support {
@@ -28,6 +32,9 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
+const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice(
+ TfLiteContext* context);
+
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 0ba170a4da..f550339d03 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -112,8 +112,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
// TODO(alanchiao): refactor scalar multiply into separate function
// for ease of adding a neon equivalent if ever necessary.
for (int j = 0; j < col_size; j++) {
+ const int8_t* value_ptr = reinterpret_cast<int8_t*>(value->data.uint8);
output->data.f[j + i * col_size] =
- value->data.uint8[j + idx * col_size] * scaling_factor;
+ value_ptr[j + idx * col_size] * scaling_factor;
}
}
}
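The reinterpret_cast matters because hybrid weights are quantized to signed int8 but stored in the tensor's uint8 buffer, so indexing data.uint8 directly turns a weight of -1 into 255 before scaling. The -1.01 weights added to the hybrid embedding tests below exercise exactly this path. A small sketch of the difference, with made-up values:

#include <cstdint>

// The byte 0xFF read as uint8 dequantizes to 255 * scale; reinterpreted as
// int8 it correctly dequantizes to -1 * scale.
float DequantizeSigned(const uint8_t* raw, int i, float scale) {
  const int8_t* signed_ptr = reinterpret_cast<const int8_t*>(raw);
  return signed_ptr[i] * scale;
}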
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 04657fd863..4a88d168c6 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 8});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
@@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
@@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
m.SetInput({1, 0, 2});
m.SetWeight({
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});
m.Invoke();
@@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
- 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
- 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
- 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
7.41e-03)));
}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
new file mode 100644
index 0000000000..0ef1a50b30
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fake_quant {
+
+// This file has the reference implementation of the FakeQuant op.
+enum KernelType {
+ kReference,
+};
+
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ const TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ if (params->narrow_range) {
+ context->ReportError(
+ context,
+ "narrow_range FakeQuant is not currently supported at runtime. "
+ "narrow_range is only meant to be applied to weights, not activations");
+ return kTfLiteError;
+ }
+
+ OpContext op_context(context, node);
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);
+ op_context.output->type = op_context.input->type;
+ return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
+ GetTensorDims(op_context.input), params->min,
+ params->max, params->num_bits,
+ GetTensorData<float>(op_context.output),
+ GetTensorDims(op_context.output));
+
+ return kTfLiteOk;
+}
+
+} // namespace fake_quant
+
+TfLiteRegistration* Register_FAKE_QUANT_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare,
+ fake_quant::Eval<fake_quant::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
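FakeQuant snaps each float to the nearest of 2^num_bits uniformly spaced levels spanning [min, max] and dequantizes it again. A hedged sketch of that round trip for a single value (the shipped reference_ops::FakeQuant additionally nudges min/max so that zero lands exactly on the grid):

#include <algorithm>
#include <cmath>

// Illustrative only. With min = 0, max = 1, num_bits = 8, the input 0.25
// snaps to round(0.25 * 255) / 255 = 64 / 255 ~= 0.25098, which is the value
// FloatPositiveRange8Test below expects.
float FakeQuantOne(float x, float min, float max, int num_bits) {
  const float levels = static_cast<float>((1 << num_bits) - 1);
  const float scale = (max - min) / levels;
  const float clamped = std::min(max, std::max(min, x));
  return min + std::round((clamped - min) / scale) * scale;
}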
diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
new file mode 100644
index 0000000000..11a02f7ed7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FakeQuantOpModel : public SingleOpModel {
+ public:
+ FakeQuantOpModel(const TensorData& input, const TensorType& output, float min,
+ float max, int num_bits) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions,
+ CreateFakeQuantOptions(builder_, min, max, num_bits).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <class T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor(input_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(FakeQuantOpTest, FloatPositiveRange8Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange8Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.896471, 0.247059, 0.501176, 0.444706, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatPositiveRange16Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange16Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.900014, 0.249998, 0.499995, 0.444431, 0})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 7816752132..6db41d7961 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -61,9 +61,17 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -96,10 +104,17 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -140,9 +155,17 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int kwidth, int kheight,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -172,10 +195,17 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -215,10 +245,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
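Every legacy wrapper above now packs its loose arguments into a single PoolParams before forwarding. New call sites can build the struct once and call the new signature directly; a minimal sketch with hypothetical 2x2 geometry and a fused ReLU6 clamp:

// Assumes float tensors already described by RuntimeShape
// ({batch, height, width, depth}); only the parameter packing matters here.
void RunAveragePool2x2(const tflite::RuntimeShape& input_shape,
                       const float* input_data,
                       const tflite::RuntimeShape& output_shape,
                       float* output_data) {
  tflite::PoolParams params;
  params.stride_height = 2;
  params.stride_width = 2;
  params.filter_height = 2;
  params.filter_width = 2;
  params.padding_values.height = 0;
  params.padding_values.width = 0;
  params.float_activation_min = 0.0f;  // fused ReLU lower bound
  params.float_activation_max = 6.0f;  // fused ReLU6 upper bound
  tflite::optimized_ops::AveragePool(params, input_shape, input_data,
                                     output_shape, output_data);
}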
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 27d9224512..4a3545d47a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -35,35 +35,6 @@ limitations under the License.
namespace tflite {
namespace multithreaded_ops {
-class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
- public:
- explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
- ~EigenThreadPoolWrapper() override {}
-
- void Schedule(std::function<void()> fn) override {
- pool_->Schedule(std::move(fn));
- }
- int NumThreads() const override { return pool_->NumThreads(); }
- int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
-
- private:
- Eigen::ThreadPool* pool_ = nullptr;
-};
-
-// We have a single global threadpool for all convolution operations. This means
-// that inferences started from different threads may block each other, but
-// since the underlying resource of CPU cores should be consumed by the
-// operations anyway, it shouldn't affect overall performance.
-const Eigen::ThreadPoolDevice& GetThreadPoolDevice() {
- const int thread_count = 4;
- static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count);
- static EigenThreadPoolWrapper* thread_pool_wrapper =
- new EigenThreadPoolWrapper(tp);
- static Eigen::ThreadPoolDevice* device =
- new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count);
- return *device;
-}
-
// Shorthands for the types we need when interfacing with the EigenTensor
// library.
typedef Eigen::TensorMap<
@@ -113,14 +84,13 @@ class EigenTensorConvFunctor {
}
public:
- void operator()(const T* input_data, T* im2col_buffer, int input_batches,
- int input_height, int input_width, int input_depth,
- const T* filter_data, int filter_height, int filter_width,
- int filter_count, int stride_rows, int stride_cols,
- int pad_width, int pad_height, TfLitePadding padding,
- T* output_data, int output_height, int output_width) {
- const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice();
-
+ void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data,
+ T* im2col_buffer, int input_batches, int input_height,
+ int input_width, int input_depth, const T* filter_data,
+ int filter_height, int filter_width, int filter_count,
+ int stride_rows, int stride_cols, int pad_width,
+ int pad_height, TfLitePadding padding, T* output_data,
+ int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
if (is_1x1_kernel) {
@@ -162,11 +132,11 @@ class EigenTensorConvFunctor {
}
};
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, TfLitePadding padding,
+inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
+ const Dims<4>& input_dims, const float* filter_data,
+ const Dims<4>& filter_dims, const float* bias_data,
+ const Dims<4>& bias_dims, int stride_width, int stride_height,
+ int pad_width, int pad_height, TfLitePadding padding,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims,
float* im2col_data, const Dims<4>& im2col_dims) {
@@ -180,10 +150,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
EigenTensorConvFunctor<float> conv_functor;
- conv_functor(input_data, im2col_data, batches, input_height, input_width,
- input_depth, filter_data, filter_height, filter_width,
- output_depth, stride_height, stride_width, pad_height, pad_width,
- padding, output_data, output_height, output_width);
+ conv_functor(device, input_data, im2col_data, batches, input_height,
+ input_width, input_depth, filter_data, filter_height,
+ filter_width, output_depth, stride_height, stride_width,
+ pad_height, pad_width, padding, output_data, output_height,
+ output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
bias_data, bias_dims, output_data, output_dims, output_activation_min,
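With the function-local static pool gone, the caller now owns the device and threads it through Conv; inside the interpreter it comes from eigen_support::GetThreadPoolDevice(context). For standalone use a device can be built by hand -- a hedged sketch, assuming an Eigen build where Eigen::ThreadPool implements Eigen::ThreadPoolInterface (otherwise wrap it the way eigen_support.cc does):

// Illustrative only: construct the device and pass it as the new leading
// argument of multithreaded_ops::Conv.
Eigen::ThreadPool pool(/*num_threads=*/4);
Eigen::ThreadPoolDevice device(&pool, /*num_threads=*/4);
// multithreaded_ops::Conv(device, input_data, input_dims, filter_data, ...);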
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 5ba7e2af9b..8c57c987d7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -55,83 +55,33 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
const int postamble_start =
m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
const float* vector_in_batch = vector + b * m_cols;
+ const float* matrix_row = matrix;
- const float* matrix_ptr0 = matrix;
- // If there is only 1 row, we don't want to assign an illegal pointer.
- const float* matrix_ptr1 = nullptr;
- if (m_rows > 1) {
- matrix_ptr1 = matrix + m_cols;
- }
-
- // Cache the vector.
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
- }
-
- // Main matrix by vector multiplication loop, which handles two rows of
- // matrix by vector multiplication.
- for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
- }
- // Add the 4 intermediate sum values to get the final dot-prod value for
- // this column.
- *result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- *(result_in_batch + result_stride) +=
- (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
- vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
- for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- *(result_in_batch + result_stride) +=
- matrix_ptr1[c] * vector_in_batch[c];
- }
- matrix_ptr0 += kUnrollSize * m_cols;
- matrix_ptr1 += kUnrollSize * m_cols;
- result_in_batch += kUnrollSize * result_stride;
- }
- for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ // Main matrix by vector multiplication loop
+ for (int r = 0; r < m_rows; r++) {
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ // Load 4 float values from vector and matrix row.
+ float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
+ float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
+ // Multiply the vector and matrix row and add to accumulator.
+ acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
}
       // Add the 4 intermediate sum values to get the final dot-product value
       // for this row.
*result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ *result_in_batch += matrix_row[c] * vector_in_batch[c];
}
- matrix_ptr0 += m_cols;
+ matrix_row += m_cols;
result_in_batch += result_stride;
}
}
- free(aligned_vector_cache_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
@@ -296,17 +246,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
const int postamble_start =
v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
- }
-
float* result_ptr = result;
const float* batch_vector_ptr = batch_vector;
for (int b = 0; b < n_batch; b++) {
@@ -314,9 +253,9 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
// Load from memory to vectors.
float32x4_t result_f32x4 = vld1q_f32(result_ptr + v);
float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
// Multiply-accumulate.
- result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4,
- vector_cache_float32x4[v >> 2]);
+ result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, vector_f32x4);
// Store.
vst1q_f32(result_ptr + v, result_f32x4);
}
@@ -328,7 +267,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
result_ptr += v_size;
batch_vector_ptr += v_size;
}
- free(aligned_vector_cache_free);
}
void NeonSub1Vector(const float* vector, int v_size, float* result) {
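The rewrite drops the cached-vector buffers and the two-row unrolling in favor of one accumulator per matrix row. A scalar rendering of what the NEON loop now computes (the intrinsics process four floats at a time and then handle the tail from postamble_start onward exactly like the inner loop here):

// result[b * m_rows * result_stride + r * result_stride] accumulates the
// dot product of matrix row r with the b-th batch vector.
void MatrixBatchVectorMultiplyAccumulateScalar(
    const float* matrix, int m_rows, int m_cols, const float* vector,
    int n_batch, float* result, int result_stride) {
  for (int b = 0; b < n_batch; ++b) {
    float* result_in_batch = result + b * m_rows * result_stride;
    const float* vector_in_batch = vector + b * m_cols;
    const float* matrix_row = matrix;
    for (int r = 0; r < m_rows; ++r) {
      float acc = 0.0f;
      for (int c = 0; c < m_cols; ++c) {
        acc += matrix_row[c] * vector_in_batch[c];
      }
      *result_in_batch += acc;
      matrix_row += m_cols;
      result_in_batch += result_stride;
    }
  }
}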
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 8597707b24..c857fdf699 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -41,6 +41,7 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
@@ -3053,6 +3054,20 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims,
output_activation_max, output_data, output_dims);
}
+inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32");
+
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] * input2_data[i], output_activation_min,
+ output_activation_max);
+ }
+}
+
template <FusedActivationFunctionType Ac>
void Mul(const int32* input1_data, const Dims<4>& input1_dims,
const int32* input2_data, const Dims<4>& input2_dims,
@@ -3771,21 +3786,20 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data,
- const RuntimeShape& input_shape, int stride_width,
- int stride_height, int pad_width, int pad_height,
- int kwidth, int kheight, float output_activation_min,
- float output_activation_max, float* output_data,
- const RuntimeShape& output_shape) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
// TODO(benoitjacob) make this a proper reference impl without Eigen!
const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
@@ -3800,12 +3814,15 @@ inline void AveragePool(const float* input_data,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
// compute elementwise sum
for (int ph = h_start; ph < h_end; ++ph) {
@@ -3823,29 +3840,21 @@ inline void AveragePool(const float* input_data,
TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
out_mat.array().rowwise() /= out_count.transpose().array();
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_shape, b, y, x, c)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_shape, b, y, x, c)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
-inline void AveragePool(const uint8* input_data,
- const RuntimeShape& input_shape, int stride_width,
- int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const RuntimeShape& output_shape) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -3854,17 +3863,21 @@ inline void AveragePool(const uint8* input_data,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
const int filter_count =
(filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
// 1280 required by Inception v3
@@ -3912,18 +3925,18 @@ inline void AveragePool(const uint8* input_data,
output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
-#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
- if (filter_count == FILTER_COUNT) { \
- for (; channel <= depth - 8; channel += 8) { \
- uint16 buf[8]; \
- for (int i = 0; i < 8; i++) { \
- buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
- } \
- uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
- vst1_u8(output_ptr + channel, buf8); \
- } \
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16 buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
}
AVGPOOL_DIVIDING_BY(9)
AVGPOOL_DIVIDING_BY(15)
@@ -3934,15 +3947,15 @@ inline void AveragePool(const uint8* input_data,
buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
}
uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, buf8);
}
#endif
for (; channel < depth; ++channel) {
uint16 a = (acc[channel] + filter_count / 2) / filter_count;
- a = std::max<uint16>(a, output_activation_min);
- a = std::min<uint16>(a, output_activation_max);
+ a = std::max<uint16>(a, params.quantized_activation_min);
+ a = std::min<uint16>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -3950,20 +3963,19 @@ inline void AveragePool(const uint8* input_data,
}
}
-inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min, float output_activation_max,
- float* output_data, const RuntimeShape& output_shape) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int depth = MatchingDim(input_shape, 3, output_shape, 3);
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
@@ -3974,12 +3986,15 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
           // compute elementwise max
for (int ph = h_start; ph < h_end; ++ph) {
@@ -3994,28 +4009,20 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
}
}
}
-
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_shape, b, y, x, c)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_shape, b, y, x, c)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
-inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const RuntimeShape& output_shape) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -4024,17 +4031,21 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
// 2048 required by Inception v3
static constexpr int kAccBufferMaxSize = 2048;
TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
@@ -4077,21 +4088,21 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
#ifdef USE_NEON
for (; channel <= depth - 16; channel += 16) {
uint8x16_t a = vld1q_u8(acc + channel);
- a = vminq_u8(a, vdupq_n_u8(output_activation_max));
- a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
vst1q_u8(output_ptr + channel, a);
}
for (; channel <= depth - 8; channel += 8) {
uint8x8_t a = vld1_u8(acc + channel);
- a = vmin_u8(a, vdup_n_u8(output_activation_max));
- a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
+ a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, a);
}
#endif
for (; channel < depth; ++channel) {
uint8 a = acc[channel];
- a = std::max<uint8>(a, output_activation_min);
- a = std::min<uint8>(a, output_activation_max);
+ a = std::max<uint8>(a, params.quantized_activation_min);
+ a = std::min<uint8>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -4099,11 +4110,9 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const RuntimeShape& output_shape) {
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Pool");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
@@ -4112,6 +4121,8 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
   // Actually carry out L2 Pool. Code is written in forward mode: we go
   // through the input values once, writing to all the pooled regions each
   // value maps to.
const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
@@ -4126,15 +4137,17 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- const int hpad = h + pad_height;
- const int wpad = w + pad_width;
- const int h_start = (hpad < filter_height)
- ? 0
- : (hpad - filter_height) / stride_height + 1;
+ const int hpad = h + params.padding_values.height;
+ const int wpad = w + params.padding_values.width;
+ const int h_start =
+ (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
const int h_end = std::min(hpad / stride_height + 1, output_height);
- const int w_start = (wpad < filter_width)
- ? 0
- : (wpad - filter_width) / stride_width + 1;
+ const int w_start =
+ (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
const int w_end = std::min(wpad / stride_width + 1, output_width);
// pre-compute square
const int in_offset = w + input_width * (h + input_height * b);
@@ -4155,6 +4168,13 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
out_count = out_count.array().inverse();
out_mat =
(out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
+
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
+ }
}
inline void LocalResponseNormalization(const float* input_data,
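The forward-mode pooling loops invert the usual iteration order: each input pixel is visited once and written into every output cell whose filter window covers it. The (h_start, h_end) bounds solve out_h * stride_height <= hpad < out_h * stride_height + filter_height for out_h. A worked instance, assuming stride 2, filter 3, pad 1:

// For input row h = 4:
//   hpad    = h + pad_height = 5
//   h_start = (hpad - filter_height) / stride_height + 1 = (5 - 3) / 2 + 1 = 2
//   h_end   = min(hpad / stride_height + 1, output_height) = min(3, H)
// so input row 4 contributes only to output row 2 (whose window covers input
// rows 3..5); output row 1 covers rows 1..3 and output row 3 covers rows 5..7.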
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 525857a2e6..9b3f1823dc 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -28,8 +28,9 @@ namespace tflite {
// Given the min and max values of a float array, return
// reasonable quantization parameters to use for this array.
template <typename T>
-QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
- const T qmin = std::numeric_limits<T>::min();
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax,
+ bool narrow_range) {
+ const T qmin = std::numeric_limits<T>::min() + (narrow_range ? 1 : 0);
const T qmax = std::numeric_limits<T>::max();
const double qmin_double = qmin;
const double qmax_double = qmax;
@@ -97,6 +98,11 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
return quantization_params;
}
+template <typename T>
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
+ return ChooseQuantizationParams<T>(rmin, rmax, false);
+}
+
// Converts a floating-point number to an integer. For all inputs x where
// static_cast<IntOut>(x) is legal according to the C++ standard, the result
// is identical to that cast (i.e. the result is x with its fractional part
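Setting narrow_range bumps qmin by one, which for signed types makes the quantized range symmetric -- the convention commonly required for weights. A worked instance, using the usual scale = (rmax - rmin) / (qmax - qmin) definition:

// int8, rmin = -1.0, rmax = 1.0:
//   narrow_range = false: qmin = -128, qmax = 127, scale = 2.0 / 255
//   narrow_range = true:  qmin = -127, qmax = 127, scale = 2.0 / 254
// With the narrow grid, w and -w quantize to q and -q exactly, since the
// range is symmetric around a zero point of 0.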
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 878b2441b4..f715d34bc1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -69,9 +69,17 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -104,10 +112,17 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -148,9 +163,17 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int kwidth, int kheight,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -180,10 +203,17 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -223,10 +253,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 9357e7407e..2d40f1769b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1429,10 +1429,11 @@ inline void BroadcastAddFivefold(
output_activation_max, output_data, output_dims);
}
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
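Templating Mul this way lets one loop serve both float and int32 element types. A sketch of an int32 instantiation (the Dims<4> object and the <limits> include are assumptions for illustration):

  // Sketch: int32 elementwise Mul through the templated kernel.
  // `dims` is assumed to describe a four-element 1x1x1x4 buffer.
  int32_t in1[4] = {-20, 2, 7, 8};
  int32_t in2[4] = {1, 2, 3, 5};
  int32_t out[4];
  Mul(in1, dims, in2, dims, std::numeric_limits<int32_t>::min(),
      std::numeric_limits<int32_t>::max(), out, dims);
  // out == {-20, 4, 21, 40}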
@@ -2273,13 +2274,10 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data,
- const RuntimeShape& input_shape, int stride_width,
- int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const RuntimeShape& output_shape) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -2288,20 +2286,24 @@ inline void AveragePool(const float* input_data,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float total = 0.f;
float filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2317,22 +2319,20 @@ inline void AveragePool(const float* input_data,
}
const float average = total / filter_count;
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
- ActivationFunctionWithMinMax(average, output_activation_min,
- output_activation_max);
+ ActivationFunctionWithMinMax(average, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-inline void AveragePool(const uint8* input_data,
- const RuntimeShape& input_shape, int stride_width,
- int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const RuntimeShape& output_shape) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -2341,20 +2341,24 @@ inline void AveragePool(const uint8* input_data,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
int32 acc = 0;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2369,8 +2373,8 @@ inline void AveragePool(const uint8* input_data,
}
}
acc = (acc + filter_count / 2) / filter_count;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
+ acc = std::max(acc, params.quantized_activation_min);
+ acc = std::min(acc, params.quantized_activation_max);
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(acc);
}
@@ -2379,11 +2383,9 @@ inline void AveragePool(const uint8* input_data,
}
}
-inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const RuntimeShape& output_shape) {
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -2392,20 +2394,24 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float sum_squares = 0.f;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2422,19 +2428,18 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
}
const float l2pool_result = std::sqrt(sum_squares / filter_count);
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
- ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
- output_activation_max);
+ ActivationFunctionWithMinMax(l2pool_result,
+ params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const RuntimeShape& output_shape) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -2443,20 +2448,24 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float max = std::numeric_limits<float>::lowest();
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2470,22 +2479,21 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
}
}
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
- ActivationFunctionWithMinMax(max, output_activation_min,
- output_activation_max);
+ ActivationFunctionWithMinMax(max, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const RuntimeShape& output_shape) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_GE(output_activation_min, 0);
- TFLITE_DCHECK_LE(output_activation_max, 255);
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_GE(params.quantized_activation_min, 0);
+ TFLITE_DCHECK_LE(params.quantized_activation_max, 255);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -2494,20 +2502,24 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
const int input_width = input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
uint8 max = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2520,8 +2532,8 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- max = std::max<uint8>(max, output_activation_min);
- max = std::min<uint8>(max, output_activation_max);
+ max = std::max<uint8>(max, params.quantized_activation_min);
+ max = std::min<uint8>(max, params.quantized_activation_max);
output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(max);
}
@@ -3717,9 +3729,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims) {
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
  // The current ArgMinMax implementation can only determine the index of the
  // minimum/maximum value in the last dimension. So the axis argument is
  // ignored.
@@ -3732,19 +3744,28 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
const int depth = ArraySize(input_dims, 0);
for (int i = 0; i < outer_size; ++i) {
- auto max_value = input_data[i * depth];
- int max_index = 0;
+ auto min_max_value = input_data[i * depth];
+ int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
const auto& curr_value = input_data[i * depth + d];
- if (curr_value > max_value) {
- max_value = curr_value;
- max_index = d;
+ if (cmp(curr_value, min_max_value)) {
+ min_max_value = curr_value;
+ min_max_index = d;
}
}
- output_data[i] = max_index;
+ output_data[i] = min_max_index;
}
}
+// TODO(renjieliu): Remove this one.
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
+ std::greater<T1>());
+}
+
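The natural ArgMin companion (registered later in this change via Register_ARG_MIN) is the same wrapper with the comparator flipped; a sketch, assuming it mirrors the ArgMax wrapper above:

  // Hypothetical mirror of the ArgMax wrapper; ArgMin only swaps
  // std::greater for std::less (both from <functional>, which the
  // surrounding code already relies on).
  template <typename T1, typename T2, typename T3>
  void ArgMin(const T3* axis, const T1* input_data,
              const tflite::Dims<4>& input_dims, T2* output_data,
              const tflite::Dims<4>& output_dims) {
    ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
              std::less<T1>());
  }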
template <typename T>
void Transpose(const T* input, const Dims<4>& input_dims, T* output,
const Dims<4>& output_dims, const int* permuted_axes) {
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index fa2420713f..737cfb69c9 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -23,7 +23,12 @@ limitations under the License.
namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
-enum class PaddingType { kNone, kSame, kValid };
+enum class PaddingType : uint8 { kNone, kSame, kValid };
+
+struct PaddingValues {
+ int8 width;
+ int8 height;
+};
// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
@@ -588,6 +593,22 @@ void ComputeStrides(Dims<N>* dims) {
}
}
+struct PoolParams {
+ FusedActivationFunctionType activation;
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int stride_height;
+ int stride_width;
+ int filter_height;
+ int filter_width;
+ // uint8, etc, inference params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float inference params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
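PoolParams carries both a quantized and a float activation range; a kernel reads only the pair matching its element type, so callers fill just that pair. A sketch for the quantized uint8 case (the stride, filter, and padding values are illustrative assumptions):

  // Sketch: PoolParams for a uint8 pool; float_activation_* stay unset
  // because the quantized kernels never read them.
  tflite::PoolParams params;
  params.stride_height = 1;
  params.stride_width = 1;
  params.filter_height = 3;
  params.filter_width = 3;
  params.padding_values.height = 1;
  params.padding_values.width = 1;
  params.quantized_activation_min = 0;
  params.quantized_activation_max = 255;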
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 3577ae6caa..4dfc891548 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -306,7 +306,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions match each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
// Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0b7c56133e..0266f5fe57 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
+//
+// TODO(alanchiao): add unit test with invalid input dimensions for this and its
+// variants.
#include <memory>
#include <vector>
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 1f72f3a3c7..349f3e6726 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -100,29 +100,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul, int32_t);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul, int32_t);
+ }
} else {
- TF_LITE_MUL(reference_ops, Mul);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul, float);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul, float);
+ }
} else {
- TF_LITE_MUL(optimized_ops, Mul);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul, float);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul, float);
+ }
}
}
#undef TF_LITE_MUL
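For one branch, TF_LITE_MUL(reference_ops, Mul, int32_t) expands to roughly the following, so the activation range is computed in the same type the kernel operates on:

  // Approximate expansion of TF_LITE_MUL(reference_ops, Mul, int32_t):
  int32_t output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  reference_ops::Mul(GetTensorData<int32_t>(input1), GetTensorDims(input1),
                     GetTensorData<int32_t>(input2), GetTensorDims(input2),
                     output_activation_min, output_activation_max,
                     GetTensorData<int32_t>(output), GetTensorDims(output));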
@@ -194,17 +209,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
TF_LITE_ENSURE_OK(
context, EvalQuantized<kernel_type>(context, node, params, data, input1,
input2, output));
} else {
- context->ReportError(
- context,
- "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.",
- output->type);
+ context->ReportError(context,
+ "Mul only supports FLOAT32, INT32 and quantized UINT8 "
+ "and INT16 now, got %d.",
+ output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
index 43d56e50d2..2807550a6b 100644
--- a/tensorflow/contrib/lite/kernels/mul_test.cc
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -52,6 +52,13 @@ class FloatMulOpModel : public BaseMulOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerMulOpModel : public BaseMulOpModel {
+ public:
+ using BaseMulOpModel::BaseMulOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
// For quantized Mul, the error shouldn't exceed (2*step + step^2).
// The param min=-1.0 & max=1.0 is used in the following tests.
// The tolerance value is ~0.0157.
@@ -133,6 +140,57 @@ TEST(FloatMulOpTest, WithBroadcast) {
}
}
+TEST(IntegerMulOpTest, NoActivation) {
+ IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40}));
+}
+
+TEST(IntegerMulOpTest, ActivationRELU_N1_TO_1) {
+ IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1}));
+}
+
+TEST(IntegerMulOpTest, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerMulOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40, 121, 20}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerMulOpTest, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerMulOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-20, 2, 7, 8, 11, 20})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedMulOpTest, NoActivation) {
QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 7240fe04cc..9b0487ae16 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -126,13 +126,19 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRange(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool(GetTensorData<float>(input), GetTensorShape(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, \
- activation_min, activation_max, \
- GetTensorData<float>(output), GetTensorShape(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -149,13 +155,19 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool(GetTensorData<uint8_t>(input), GetTensorShape(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, \
- activation_min, activation_max, \
- GetTensorData<uint8_t>(output), GetTensorShape(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -171,13 +183,18 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRange(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_MAX_POOL(type) \
- type::MaxPool(GetTensorData<float>(input), GetTensorShape(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), \
- GetTensorShape(output))
+#define TF_LITE_MAX_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -194,13 +211,19 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_MAX_POOL(type) \
- type::MaxPool(GetTensorData<uint8_t>(input), GetTensorShape(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<uint8_t>(output), \
- GetTensorShape(output))
+#define TF_LITE_MAX_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -216,13 +239,18 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRange(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_L2_POOL(type) \
- type::L2Pool(GetTensorData<float>(input), GetTensorShape(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), \
- GetTensorShape(output))
+#define TF_LITE_L2_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::L2Pool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2_POOL(reference_ops);
} else {
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0ca08cd8f3..22a507e6a4 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -82,6 +82,7 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
@@ -102,6 +103,7 @@ TfLiteRegistration* Register_SQRT();
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_FAKE_QUANT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -167,6 +169,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
@@ -187,6 +190,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_POW, Register_POW());
+ AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 9b6cee3cb5..3cdb5db209 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
+ case kTfLiteInt16: \
+ TF_LITE_SELECT(int16_t, op); \
+ break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
break; \
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
index 4664b9acb4..5b2e61cd29 100644
--- a/tensorflow/contrib/lite/kernels/select_test.cc
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
+TEST(SelectOpTest, SelectInt16) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT16);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
TEST(SelectOpTest, SelectInt32) {
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
TensorType_INT32);
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 7182374a6f..8b9deeed20 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -39,35 +38,9 @@ constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kOutputTensor = 0;
-const int kTensorNotAllocated = -1;
-
-struct OpData {
- // IDs are the arbitrary identifiers used by TF Lite to identify and access
- // memory buffers.
- int im2col_id = kTensorNotAllocated;
-
- // im2col is the only temporary currently tracked, therefore always index 0.
- // If more temporaries are added, they should be properly tracked.
- int32_t im2col_index = 0;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- // This is a builtin op, so we don't use the contents in 'buffer', if any.
- // Instead, we allocate a new object to use as scratch space for im2col, and
- // to carry information from Prepare() to Eval().
- auto* data = new OpData;
- eigen_support::IncrementUsageCounter(context);
- return data;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- eigen_support::DecrementUsageCounter(context);
- delete reinterpret_cast<OpData*>(buffer);
-}
-
-TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- TfLiteTensor* output) {
+TfLiteStatus ResizeOutputShape(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
// Currently only support int32 for output shape.
if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "Output shape is %d, not int32.",
@@ -83,60 +56,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
return context->ResizeTensor(context, output, output_shape_array);
}
-// Allocate temporary im2col tensor.
-static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context,
- TfLiteNode* node) {
- OpData* data = reinterpret_cast<OpData*>(node->user_data);
- if (data->im2col_id == kTensorNotAllocated) {
- context->AddTensors(context, 1, &data->im2col_id);
- }
-
- TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[data->im2col_index] = data->im2col_id;
-
- return kTfLiteOk;
-}
-
-TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- const TfLiteTensor* weights,
- const TfLiteTensor* input,
- TfLiteTensor* im2col) {
- if (output_shape->type != kTfLiteInt32) {
- context->ReportError(context, "im2col shape is %d, not int32.",
- output_shape->type);
- return kTfLiteError;
- }
- TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
- TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
- im2col_shape_array->data[0] = output_shape->data.i32[0];
- im2col_shape_array->data[1] = output_shape->data.i32[1];
- im2col_shape_array->data[2] = output_shape->data.i32[2];
- const int input_depth = SizeOfDimension(input, 3);
- const int filter_width = SizeOfDimension(weights, 1);
- const int filter_height = SizeOfDimension(weights, 2);
- im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
-
- im2col->type = input->type;
- im2col->allocation_type = kTfLiteArenaRw;
- return context->ResizeTensor(context, im2col, im2col_shape_array);
-}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node));
-
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* im2col =
- &context->tensors[node->temporaries->data[user_data->im2col_index]];
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -153,15 +81,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
SizeOfDimension(weights, 3));
- if (IsConstantTensor(output_shape)) {
- TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
- TF_LITE_ENSURE_STATUS(
- ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
- } else {
- // Defer resizing until Eval().
+ if (!IsConstantTensor(output_shape)) {
SetTensorToDynamic(output);
+ return kTfLiteOk;
}
- return kTfLiteOk;
+ return ResizeOutputShape(context, output_shape, output);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -170,19 +94,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* im2col =
- &context->tensors[node->temporaries->data[user_data->im2col_index]];
+
const auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
- ResizeOutputTensor(context, output_shape, output));
- }
- if (IsDynamicTensor(im2col)) {
- TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
- weights, input, im2col));
+ ResizeOutputShape(context, output_shape, output));
}
// Get height and width of the output image.
@@ -201,12 +119,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
case kTfLiteFloat32:
- optimized_ops::TransposeConv(
+ reference_ops::TransposeConv(
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ // Last two args specify im2col, which reference_ops ignores.
+ // (Note this does not lead to a performance regression, as the
+ // previous optimized version was just a copy of the reference code.)
+ // TODO(b/110208176): Allocate im2col tensors and switch to
+ // optimized_ops.
+ GetTensorData<float>(output), GetTensorDims(output));
break;
default:
context->ReportError(context, "Type %d, not currently supported.",
@@ -219,8 +142,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace transpose_conv
TfLiteRegistration* Register_TRANSPOSE_CONV() {
- static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free,
- transpose_conv::Prepare, transpose_conv::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare,
+ transpose_conv::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index c741df19de..55df897180 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -25,49 +24,9 @@ namespace {
using ::testing::ElementsAreArray;
-class ConstTransposeConvOpModel : public SingleOpModel {
- // Just to be extra confusing, transpose_conv has an _input_ named
- // "output_shape". This input sets the shape of the output tensor of the op.
- // In this version of the test class, "output_shape" is a constant that must
- // be specified in the constructor.
- public:
- ConstTransposeConvOpModel(TfLiteRegistration* registration,
- std::initializer_list<int> input_shape,
- std::initializer_list<int> filter_shape,
- std::initializer_list<int> output_shape_data,
- Padding padding, int stride_w, int stride_h) {
- output_shape_ = AddConstInput(TensorType_INT32, output_shape_data,
- {static_cast<int>(output_shape_data.size())});
- filter_ = AddInput(TensorType_FLOAT32);
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
- CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
- .Union());
- resolver_ = absl::make_unique<SingleOpResolver>(
- BuiltinOperator_TRANSPOSE_CONV, registration);
- BuildInterpreter({{4}, filter_shape, input_shape});
- }
-
- int output_shape() { return output_shape_; }
- int filter() { return filter_; }
- int input() { return input_; }
-
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int output_shape_;
- int filter_;
- int input_;
- int output_;
-};
-
class TransposeConvOpModel : public SingleOpModel {
public:
- TransposeConvOpModel(TfLiteRegistration* registration,
- std::initializer_list<int> input_shape,
+ TransposeConvOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> filter_shape, Padding padding,
int stride_w, int stride_h) {
output_shape_ = AddInput(TensorType_INT32);
@@ -78,8 +37,6 @@ class TransposeConvOpModel : public SingleOpModel {
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
.Union());
- resolver_ = absl::make_unique<SingleOpResolver>(
- BuiltinOperator_TRANSPOSE_CONV, registration);
BuildInterpreter({{4}, filter_shape, input_shape});
}
@@ -97,15 +54,6 @@ class TransposeConvOpModel : public SingleOpModel {
int output_;
};
-const auto kKernelMap = new std::map<string, TfLiteRegistration*>({});
-
-class TransposeConvOpTest : public SingleOpTest {
- protected:
- const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
- return *kKernelMap;
- }
-};
-
// Test case:
// output = tf.nn.conv2d_backprop_input(
// tf.constant([ 1, 4, 4, 1 ]),
@@ -113,9 +61,8 @@ class TransposeConvOpTest : public SingleOpTest {
// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32),
// [1, 1, 1, 1 ],
// "SAME")
-TEST_P(TransposeConvOpTest, SimpleTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 3, 3, 1},
- Padding_SAME, 1, 1);
+TEST(TransposeConvOpModelTest, SimpleTest) {
+ TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(
@@ -128,21 +75,6 @@ TEST_P(TransposeConvOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
-// Test case: Same as above, but with a const "output_shape"
-TEST_P(TransposeConvOpTest, ConstSimpleTest) {
- ConstTransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 4, 4, 1},
- {1, 3, 3, 1}, Padding_SAME, 1, 1);
- m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
- m.PopulateTensor<float>(
- m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
- m.Invoke();
-
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
- 417, 330, 263, 446, 485, 365}));
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
-}
-
// Test case:
// filter = tf.constant(np.arange(1, 19),
// shape=[ 3, 3, 1, 2 ],
@@ -155,9 +87,8 @@ TEST_P(TransposeConvOpTest, ConstSimpleTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
-TEST_P(TransposeConvOpTest, TwoFiltersTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
- Padding_SAME, 1, 1);
+TEST(TransposeConvOpModelTest, TwoFiltersTest) {
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -185,9 +116,8 @@ TEST_P(TransposeConvOpTest, TwoFiltersTest) {
// "VALID")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
-TEST_P(TransposeConvOpTest, PaddingValidTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
- Padding_VALID, 1, 1);
+TEST(TransposeConvOpModelTest, PaddingValidTest) {
+ TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18});
@@ -216,9 +146,8 @@ TEST_P(TransposeConvOpTest, PaddingValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST_P(TransposeConvOpTest, StrideValidTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {1, 3, 3, 1},
- Padding_VALID, 2, 2);
+TEST(TransposeConvOpModelTest, StrideValidTest) {
+ TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
@@ -241,9 +170,8 @@ TEST_P(TransposeConvOpTest, StrideValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST_P(TransposeConvOpTest, MultiChannelTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
- Padding_VALID, 2, 2);
+TEST(TransposeConvOpModelTest, MultiChannelTest) {
+ TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
8, 10, 12, 14, 16, 18});
@@ -259,24 +187,6 @@ TEST_P(TransposeConvOpTest, MultiChannelTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
}
-// Test case: Same as above, but with a const "output_shape"
-TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
- ConstTransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
- {1, 5, 5, 2}, Padding_VALID, 2, 2);
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
- m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
- m.Invoke();
-
- EXPECT_THAT(
- m.GetOutput(),
- ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9,
- 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72,
- 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44,
- 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72}));
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
-}
-
// Test case:
// filter = tf.constant(np.random.randint(1, 10, size=9),
// shape=[ 3, 3, 1, 1 ],
@@ -289,9 +199,8 @@ TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
-TEST_P(TransposeConvOpTest, AccuracyTest) {
- TransposeConvOpModel m(GetRegistration(), {1, 1, 2, 1}, {1, 3, 3, 1},
- Padding_SAME, 3, 3);
+TEST(TransposeConvOpModelTest, AccuracyTest) {
+ TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3);
m.PopulateTensor<int>(m.output_shape(), {1, 3, 4, 1});
m.PopulateTensor<float>(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4});
m.PopulateTensor<float>(m.input(), {323, 521});
@@ -303,10 +212,6 @@ TEST_P(TransposeConvOpTest, AccuracyTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
}
-INSTANTIATE_TEST_CASE_P(
- TransposeConvOpTest, TransposeConvOpTest,
- ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
-
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 32daf2bb02..c48b470f92 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -274,7 +274,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions match each other.
- CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index c448fb71db..71e38c3f13 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -186,6 +186,8 @@ InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
op_resolver_(op_resolver),
error_reporter_(ValidateErrorReporter(error_reporter)) {}
+InterpreterBuilder::~InterpreterBuilder() {}
+
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
auto opcodes = model_->operator_codes();
@@ -204,8 +206,9 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
} else if (builtin_code != BuiltinOperator_CUSTOM) {
registration = op_resolver_.FindOp(builtin_code, version);
if (registration == nullptr) {
- error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
- EnumNameBuiltinOperator(builtin_code));
+ error_reporter_->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
status = kTfLiteError;
}
} else if (!opcode->custom_code()) {
@@ -661,6 +664,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
MallocPOD<TfLiteTransposeConvParams>();
@@ -697,6 +709,17 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
error_reporter->Report("DELEGATE op shouldn't exist in model.");
return kTfLiteError;
}
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ params->narrow_range = schema_params->narrow_range();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:
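The FAKE_QUANT parse above implies a TfLiteFakeQuantParams with four fields mirroring FakeQuantOptions; the actual declaration lives in builtin_op_data.h, which this diff does not touch, so treat the following as an inferred sketch:

  // Inferred from the parse above, not copied from builtin_op_data.h.
  typedef struct {
    float min;
    float max;
    int num_bits;
    bool narrow_range;
  } TfLiteFakeQuantParams;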
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 3946b49041..8bc9ecd7ce 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -156,6 +156,7 @@ class InterpreterBuilder {
InterpreterBuilder(const ::tflite::Model* model,
const OpResolver& op_resolver,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ ~InterpreterBuilder();
InterpreterBuilder(const InterpreterBuilder&) = delete;
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 905c0919cb..cc668485a4 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -548,6 +548,18 @@ TfLiteStatus AddOpsAndParams(
add_squeeze_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_SQUEEZE;
break;
+ case tflite::BuiltinOperator_TRANSPOSE:
+ // The permutation input tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((node.inputs->size > 1) &&
+ (interpreter->tensor(node.inputs->data[1])->allocation_type !=
+ kTfLiteMmapRo)) {
+ logError("NNAPI does not yet support dynamic tensors.");
+ return kTfLiteError;
+ }
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_TRANSPOSE;
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
@@ -567,7 +579,6 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
case tflite::BuiltinOperator_TOPK_V2:
- case tflite::BuiltinOperator_TRANSPOSE:
case tflite::BuiltinOperator_SPLIT:
case tflite::BuiltinOperator_STRIDED_SLICE:
case tflite::BuiltinOperator_EXP:
@@ -579,6 +590,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_ARG_MIN:
case tflite::BuiltinOperator_GREATER:
case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
@@ -599,6 +611,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_RSQRT:
case tflite::BuiltinOperator_SHAPE:
case tflite::BuiltinOperator_POW:
+ case tflite::BuiltinOperator_FAKE_QUANT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 27909a9458..8c9608db04 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -19,6 +19,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
+ "//tensorflow/python:util",
],
)
@@ -30,9 +31,10 @@ py_test(
tags = ["no_oss"],
deps = [
":interpreter",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index fd90823425..e1981ceae2 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -56,9 +56,6 @@ class Interpreter(object):
self._interpreter = (
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
model_content))
- if not self._interpreter:
- raise ValueError(
- 'Failed to create model from {} bytes'.format(len(model_content)))
elif not model_path and not model_content:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
@@ -66,8 +63,7 @@ class Interpreter(object):
def allocate_tensors(self):
self._ensure_safe()
- if not self._interpreter.AllocateTensors():
- raise ValueError('Failed to allocate tensors')
+ return self._interpreter.AllocateTensors()
def _safe_to_run(self):
"""Returns true if there exist no numpy array buffers.
@@ -152,8 +148,7 @@ class Interpreter(object):
Raises:
ValueError: If the interpreter could not set the tensor.
"""
- if not self._interpreter.SetTensor(tensor_index, value):
- raise ValueError('Failed to set tensor')
+ self._interpreter.SetTensor(tensor_index, value)
def resize_tensor_input(self, input_index, tensor_size):
"""Resizes an input tensor.
@@ -167,8 +162,7 @@ class Interpreter(object):
ValueError: If the interpreter could not resize the input tensor.
"""
self._ensure_safe()
- if not self._interpreter.ResizeInputTensor(input_index, tensor_size):
- raise ValueError('Failed to resize input')
+ self._interpreter.ResizeInputTensor(input_index, tensor_size)
def get_output_details(self):
"""Gets model output details.
@@ -181,7 +175,9 @@ class Interpreter(object):
]
def get_tensor(self, tensor_index):
- """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`.
+ """Gets the value of the input tensor (get a copy).
+
+ If you wish to avoid the copy, use `tensor()`.
Args:
tensor_index: Tensor index of tensor to get. This value can be obtained from
@@ -247,5 +243,7 @@ class Interpreter(object):
ValueError: When the underlying interpreter fails.
"""
self._ensure_safe()
- if not self._interpreter.Invoke():
- raise ValueError('Failed to invoke TFLite model')
+ self._interpreter.Invoke()
+
+ def reset_all_variables_to_zero(self):
+ return self._interpreter.ResetVariableTensorsToZero()
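
A minimal usage sketch of the reworked Python API (the model path here is hypothetical): failures during loading, allocation, resizing, or invocation now arrive as Python exceptions raised inside the SWIG wrapper, carrying the PythonErrorReporter's message, instead of ValueErrors synthesized from boolean return codes.

    from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper

    try:
      interpreter = interpreter_wrapper.Interpreter(model_path='model.tflite')
      interpreter.allocate_tensors()
      interpreter.invoke()
      # New in this change: zero out stateful (variable) tensors.
      interpreter.reset_all_variables_to_zero()
    except (ValueError, RuntimeError) as e:
      print('TFLite interpreter error:', e)
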
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index 5f1fa26c3b..95fa4b8584 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import io
import numpy as np
+import six
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
from tensorflow.python.framework import test_util
@@ -91,6 +92,28 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertTrue((expected_output == output_data).all())
+class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
+
+ def testInvalidModelContent(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Model provided has model identifier \''):
+ interpreter_wrapper.Interpreter(model_content=six.b('garbage'))
+
+ def testInvalidModelFile(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Could not open \'totally_invalid_file_name\''):
+ interpreter_wrapper.Interpreter(
+ model_path='totally_invalid_file_name')
+
+ def testInvokeBeforeReady(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
+ with self.assertRaisesRegexp(RuntimeError,
+ 'Invoke called on model that is not ready'):
+ interpreter.invoke()
+
+
class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
def setUp(self):
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
index 634c2a1e1f..69ee95c320 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
@@ -13,7 +13,6 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/core:lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 5554d08fa0..c38b692dcd 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -14,13 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+#include <sstream>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/core/platform/logging.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -38,9 +38,58 @@ limitations under the License.
#define CPP_TO_PYSTRING PyString_FromStringAndSize
#endif
+#define TFLITE_PY_CHECK(x) \
+ if ((x) != kTfLiteOk) { \
+ return error_reporter_->exception(); \
+ }
+
+#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \
+ if (i >= interpreter_->tensors_size() || i < 0) { \
+ PyErr_Format(PyExc_ValueError, \
+ "Invalid tensor index %d exceeds max tensor index %lu", i, \
+ interpreter_->tensors_size()); \
+ return nullptr; \
+ }
+
+#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
+ if (!interpreter_) { \
+ PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
+ return nullptr; \
+ }
+
namespace tflite {
namespace interpreter_wrapper {
+class PythonErrorReporter : public tflite::ErrorReporter {
+ public:
+ PythonErrorReporter() {}
+
+ // Report an error message
+ int Report(const char* format, va_list args) override {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << buf;
+ return formatted;
+ }
+
+ // Sets a Python runtime exception with the last error.
+ PyObject* exception() {
+ std::string last_message = message();
+ PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+ return nullptr;
+ }
+
+ // Gets the last error message and clears the buffer.
+ std::string message() {
+ std::string value = buffer_.str();
+ buffer_.str("");  // str("") clears the contents; clear() only resets stream state.
+ return value;
+ }
+
+ private:
+ std::stringstream buffer_;
+};
+
namespace {
// Calls PyArray's initialization to initialize all the API pointers. Note that
@@ -60,19 +109,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
- if (interpreter) {
- for (const int input_index : interpreter->inputs()) {
- const TfLiteTensor* tensor = interpreter->tensor(input_index);
- CHECK(tensor);
- const TfLiteIntArray* dims = tensor->dims;
- if (!dims) {
- continue;
- }
-
- std::vector<int> input_dims(dims->data, dims->data + dims->size);
- interpreter->ResizeInputTensor(input_index, input_dims);
- }
- }
return interpreter;
}
@@ -95,10 +131,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteNoType:
- return -1;
+ return NPY_NOTYPE;
+ // Avoid a default case so the compiler errors out when new types are added.
}
- LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type;
- return -1;
+ return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
@@ -122,8 +158,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteString;
case NPY_COMPLEX64:
return kTfLiteComplex64;
+ // Avoid a default case so the compiler errors out when new types are added.
}
- LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
return kTfLiteNoType;
}
@@ -147,32 +183,29 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
InterpreterWrapper::InterpreterWrapper(
- std::unique_ptr<tflite::FlatBufferModel> model)
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter)
: model_(std::move(model)),
+ error_reporter_(std::move(error_reporter)),
resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
InterpreterWrapper::~InterpreterWrapper() {}
-bool InterpreterWrapper::AllocateTensors() {
- if (!interpreter_) {
- LOG(ERROR) << "Cannot allocate tensors: invalid interpreter.";
- return false;
- }
-
- if (interpreter_->AllocateTensors() != kTfLiteOk) {
- LOG(ERROR) << "Unable to allocate tensors.";
- return false;
- }
-
- return true;
+PyObject* InterpreterWrapper::AllocateTensors() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->AllocateTensors());
+ Py_RETURN_NONE;
}
-bool InterpreterWrapper::Invoke() {
- return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false;
+PyObject* InterpreterWrapper::Invoke() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->Invoke());
+ Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::InputIndices() const {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
interpreter_->inputs().size());
@@ -186,35 +219,36 @@ PyObject* InterpreterWrapper::OutputIndices() const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
-bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
+PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert numpy value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
if (PyArray_NDIM(array) != 1) {
- LOG(ERROR) << "Expected 1-D defining input shape.";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
+ PyArray_NDIM(array));
+ return nullptr;
}
if (PyArray_TYPE(array) != NPY_INT32) {
- LOG(ERROR) << "Shape must be an int32 array";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
+ PyArray_TYPE(array));
+ return nullptr;
}
std::vector<int> dims(PyArray_SHAPE(array)[0]);
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
- return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk);
+ TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
+ Py_RETURN_NONE;
}
std::string InterpreterWrapper::TensorName(int i) const {
@@ -227,21 +261,21 @@ std::string InterpreterWrapper::TensorName(int i) const {
}
PyObject* InterpreterWrapper::TensorType(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- return nullptr;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
- int typenum = TfLiteTypeToPyArrayType(tensor->type);
- return PyArray_TypeObjectFromType(typenum);
+ int code = TfLiteTypeToPyArrayType(tensor->type);
+ if (code == NPY_NOTYPE) {
+ PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
+ return nullptr;
+ }
+ return PyArray_TypeObjectFromType(code);
}
PyObject* InterpreterWrapper::TensorSize(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
@@ -250,97 +284,82 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
return PyTupleFromQuantizationParam(tensor->params);
}
-bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
-
- if (i >= interpreter_->tensors_size()) {
- LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
- << interpreter_->tensors_size();
- return false;
- }
+PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
if (TfLiteTypeFromPyArray(array) != tensor->type) {
- LOG(ERROR) << "Cannot set tensor:"
- << " Got tensor of type " << TfLiteTypeFromPyArray(array)
- << " but expected type " << tensor->type << " for input " << i;
- return false;
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor:"
+ " Got tensor of type %d"
+ " but expected type %d for input %d ",
+ TfLiteTypeFromPyArray(array), tensor->type, i);
+ return nullptr;
}
if (PyArray_NDIM(array) != tensor->dims->size) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
}
size_t size = PyArray_NBYTES(array);
- DCHECK_EQ(size, tensor->bytes);
+ if (size != tensor->bytes) {
+ PyErr_Format(PyExc_ValueError,
+ "numpy array had %zu bytes but expected %zu bytes.", size,
+ tensor->bytes);
+ return nullptr;
+ }
memcpy(tensor->data.raw, PyArray_DATA(array), size);
- return true;
+ Py_RETURN_NONE;
}
namespace {
-PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index,
+PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
- if (!interpreter) {
- LOG(ERROR) << "Invalid interpreter.";
- Py_INCREF(Py_None);
- return Py_None;
- }
-
- if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) {
- LOG(ERROR) << "Invalid tensor index: " << tensor_index
- << " exceeds max tensor index " << interpreter->inputs().size();
- Py_INCREF(Py_None);
- return Py_None;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
- *tensor = interpreter->tensor(tensor_index);
+ *tensor = interpreter_->tensor(tensor_index);
if ((*tensor)->bytes == 0) {
- LOG(ERROR) << "Invalid tensor size";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
+ return nullptr;
}
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
if (*type_num == NPY_NOTYPE) {
- LOG(ERROR) << "Unknown tensor type " << (*tensor)->type;
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
+ return nullptr;
}
if (!(*tensor)->data.raw) {
- LOG(ERROR) << "Tensor data is null.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
+ return nullptr;
}
return nullptr;
@@ -362,9 +381,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// it will leak.
void* data = malloc(tensor->bytes);
if (!data) {
- LOG(ERROR) << "Malloc to copy tensor failed.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
+ return nullptr;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
@@ -394,22 +412,39 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
- const char* model_path) {
+ const char* model_path, std::string* error_msg) {
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
- PyObject* data) {
+ PyObject* data, std::string* error_msg) {
char * buf = nullptr;
Py_ssize_t length;
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
if (PY_TO_CPPSTRING(data, &buf, &length) == -1) {
return nullptr;
}
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromBuffer(buf, length,
+ error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+}
+
+PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ Py_RETURN_NONE;
}
} // namespace interpreter_wrapper
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 681448be20..556ec7117a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -15,13 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
+// Place `<locale>` before <Python.h> to avoid build failures on macOS.
+#include <locale>
#include <memory>
#include <string>
#include <vector>
-// Place `<locale>` before <Python.h> to avoid build failures in macOS.
#include <Python.h>
-#include <locale>
// We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite {
@@ -36,34 +36,41 @@ class Interpreter;
namespace interpreter_wrapper {
+class PythonErrorReporter;
+
class InterpreterWrapper {
public:
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path);
+ static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path,
+ std::string* error_msg);
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data);
+ static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data,
+ std::string* error_msg);
~InterpreterWrapper();
- bool AllocateTensors();
- bool Invoke();
+ PyObject* AllocateTensors();
+ PyObject* Invoke();
PyObject* InputIndices() const;
PyObject* OutputIndices() const;
- bool ResizeInputTensor(int i, PyObject* value);
+ PyObject* ResizeInputTensor(int i, PyObject* value);
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
PyObject* TensorQuantization(int i) const;
- bool SetTensor(int i, PyObject* value);
+ PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
+ PyObject* ResetVariableTensorsToZero();
+
// Returns a reference to tensor index i as a numpy array. The base_object
// should be the interpreter object providing the memory.
PyObject* tensor(PyObject* base_object, int i);
private:
- InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model);
+ InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter);
// InterpreterWrapper is not copyable or assignable. We avoid the use of
// InterpreterWrapper() = delete here for SWIG compatibility.
@@ -71,6 +78,7 @@ class InterpreterWrapper {
InterpreterWrapper(const InterpreterWrapper& rhs);
const std::unique_ptr<tflite::FlatBufferModel> model_;
+ const std::unique_ptr<PythonErrorReporter> error_reporter_;
const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
const std::unique_ptr<tflite::Interpreter> interpreter_;
};
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
index 7f51f9f00d..afb2092eac 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
@@ -18,8 +18,51 @@ limitations under the License.
%{
#define SWIG_FILE_WITH_INIT
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
%}
%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+
+namespace tflite {
+namespace interpreter_wrapper {
+%extend InterpreterWrapper {
+
+ // Version of the factory function that raises a Python exception
+ // carrying the error string on failure.
+ static PyObject* CreateWrapperCPPFromFile(const char* model_path) {
+ std::string error;
+ if (tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromFile(
+ model_path, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+
+ // Version of the factory function that raises a Python exception
+ // carrying the error string on failure.
+ static PyObject* CreateWrapperCPPFromBuffer(
+ PyObject* data) {
+ std::string error;
+ if (tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromBuffer(
+ data, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+}
+
+} // namespace interpreter_wrapper
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 15fb8bbdb8..64830b1dc3 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -44,7 +44,7 @@ enum TensorType : byte {
table QuantizationParameters {
min:[float]; // For importing back into tensorflow.
max:[float]; // For importing back into tensorflow.
- scale:[float];
+ scale:[float]; // For dequantizing the tensor's values.
zero_point:[long];
}
@@ -160,6 +160,8 @@ enum BuiltinOperator : byte {
RSQRT = 76,
SHAPE = 77,
POW = 78,
+ ARG_MIN = 79,
+ FAKE_QUANT = 80,
}
// Options for the builtin operators.
@@ -220,6 +222,8 @@ union BuiltinOptions {
NotEqualOptions,
ShapeOptions,
PowOptions,
+ ArgMinOptions,
+ FakeQuantOptions,
}
enum Padding : byte { SAME, VALID }
@@ -469,6 +473,10 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table ArgMinOptions {
+ output_type : TensorType;
+}
+
table GreaterOptions {
}
@@ -517,6 +525,16 @@ table ShapeOptions {
table PowOptions {
}
+table FakeQuantOptions {
+ // Parameters supported by version 1:
+ min:float;
+ max:float;
+ num_bits:int;
+
+ // Parameters supported by version 2:
+ narrow_range:bool;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
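
For orientation, the four FakeQuantOptions fields parameterize the usual fake-quantization transfer function. A numpy sketch of that arithmetic (illustrative only: it skips the min/max nudging TensorFlow's fake_quant ops perform, and it is not the TFLite kernel):

    import numpy as np

    def fake_quant(x, min_val, max_val, num_bits=8, narrow_range=False):
      # narrow_range drops the lowest code, as used for symmetric weights.
      qmin = 1 if narrow_range else 0
      qmax = (1 << num_bits) - 1
      scale = (max_val - min_val) / float(qmax - qmin)
      zero_point = qmin - min_val / scale
      q = np.clip(np.round(x / scale + zero_point), qmin, qmax)
      return (q - zero_point) * scale
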
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index fe0ff9a7a5..c0b57039cb 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -157,6 +157,9 @@ struct TileOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct ArgMinOptions;
+struct ArgMinOptionsT;
+
struct GreaterOptions;
struct GreaterOptionsT;
@@ -199,6 +202,9 @@ struct ShapeOptionsT;
struct PowOptions;
struct PowOptionsT;
+struct FakeQuantOptions;
+struct FakeQuantOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -343,11 +349,13 @@ enum BuiltinOperator {
BuiltinOperator_RSQRT = 76,
BuiltinOperator_SHAPE = 77,
BuiltinOperator_POW = 78,
+ BuiltinOperator_ARG_MIN = 79,
+ BuiltinOperator_FAKE_QUANT = 80,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_POW
+ BuiltinOperator_MAX = BuiltinOperator_FAKE_QUANT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[80] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -426,7 +434,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
BuiltinOperator_SQRT,
BuiltinOperator_RSQRT,
BuiltinOperator_SHAPE,
- BuiltinOperator_POW
+ BuiltinOperator_POW,
+ BuiltinOperator_ARG_MIN,
+ BuiltinOperator_FAKE_QUANT
};
return values;
}
@@ -512,6 +522,8 @@ inline const char **EnumNamesBuiltinOperator() {
"RSQRT",
"SHAPE",
"POW",
+ "ARG_MIN",
+ "FAKE_QUANT",
nullptr
};
return names;
@@ -580,11 +592,13 @@ enum BuiltinOptions {
BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_ShapeOptions = 55,
BuiltinOptions_PowOptions = 56,
+ BuiltinOptions_ArgMinOptions = 57,
+ BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_PowOptions
+ BuiltinOptions_MAX = BuiltinOptions_FakeQuantOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[59] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -642,7 +656,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
BuiltinOptions_EqualOptions,
BuiltinOptions_NotEqualOptions,
BuiltinOptions_ShapeOptions,
- BuiltinOptions_PowOptions
+ BuiltinOptions_PowOptions,
+ BuiltinOptions_ArgMinOptions,
+ BuiltinOptions_FakeQuantOptions
};
return values;
}
@@ -706,6 +722,8 @@ inline const char **EnumNamesBuiltinOptions() {
"NotEqualOptions",
"ShapeOptions",
"PowOptions",
+ "ArgMinOptions",
+ "FakeQuantOptions",
nullptr
};
return names;
@@ -944,6 +962,14 @@ template<> struct BuiltinOptionsTraits<PowOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
};
+template<> struct BuiltinOptionsTraits<ArgMinOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FakeQuantOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1423,6 +1449,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_PowOptions ?
reinterpret_cast<const PowOptionsT *>(value) : nullptr;
}
+ ArgMinOptionsT *AsArgMinOptions() {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<ArgMinOptionsT *>(value) : nullptr;
+ }
+ const ArgMinOptionsT *AsArgMinOptions() const {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<const ArgMinOptionsT *>(value) : nullptr;
+ }
+ FakeQuantOptionsT *AsFakeQuantOptions() {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<FakeQuantOptionsT *>(value) : nullptr;
+ }
+ const FakeQuantOptionsT *AsFakeQuantOptions() const {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<const FakeQuantOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4486,6 +4528,60 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ArgMinOptionsT : public flatbuffers::NativeTable {
+ typedef ArgMinOptions TableType;
+ TensorType output_type;
+ ArgMinOptionsT()
+ : output_type(TensorType_FLOAT32) {
+ }
+};
+
+struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArgMinOptionsT NativeTableType;
+ enum {
+ VT_OUTPUT_TYPE = 4
+ };
+ TensorType output_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE) &&
+ verifier.EndTable();
+ }
+ ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ArgMinOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ArgMinOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_type(TensorType output_type) {
+ fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &);
+ flatbuffers::Offset<ArgMinOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArgMinOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType output_type = TensorType_FLOAT32) {
+ ArgMinOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct GreaterOptionsT : public flatbuffers::NativeTable {
typedef GreaterOptions TableType;
GreaterOptionsT() {
@@ -5112,6 +5208,96 @@ inline flatbuffers::Offset<PowOptions> CreatePowOptions(
flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct FakeQuantOptionsT : public flatbuffers::NativeTable {
+ typedef FakeQuantOptions TableType;
+ float min;
+ float max;
+ int32_t num_bits;
+ bool narrow_range;
+ FakeQuantOptionsT()
+ : min(0.0f),
+ max(0.0f),
+ num_bits(0),
+ narrow_range(false) {
+ }
+};
+
+struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FakeQuantOptionsT NativeTableType;
+ enum {
+ VT_MIN = 4,
+ VT_MAX = 6,
+ VT_NUM_BITS = 8,
+ VT_NARROW_RANGE = 10
+ };
+ float min() const {
+ return GetField<float>(VT_MIN, 0.0f);
+ }
+ float max() const {
+ return GetField<float>(VT_MAX, 0.0f);
+ }
+ int32_t num_bits() const {
+ return GetField<int32_t>(VT_NUM_BITS, 0);
+ }
+ bool narrow_range() const {
+ return GetField<uint8_t>(VT_NARROW_RANGE, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<float>(verifier, VT_MIN) &&
+ VerifyField<float>(verifier, VT_MAX) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BITS) &&
+ VerifyField<uint8_t>(verifier, VT_NARROW_RANGE) &&
+ verifier.EndTable();
+ }
+ FakeQuantOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FakeQuantOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FakeQuantOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min(float min) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MIN, min, 0.0f);
+ }
+ void add_max(float max) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MAX, max, 0.0f);
+ }
+ void add_num_bits(int32_t num_bits) {
+ fbb_.AddElement<int32_t>(FakeQuantOptions::VT_NUM_BITS, num_bits, 0);
+ }
+ void add_narrow_range(bool narrow_range) {
+ fbb_.AddElement<uint8_t>(FakeQuantOptions::VT_NARROW_RANGE, static_cast<uint8_t>(narrow_range), 0);
+ }
+ explicit FakeQuantOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FakeQuantOptionsBuilder &operator=(const FakeQuantOptionsBuilder &);
+ flatbuffers::Offset<FakeQuantOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FakeQuantOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ float min = 0.0f,
+ float max = 0.0f,
+ int32_t num_bits = 0,
+ bool narrow_range = false) {
+ FakeQuantOptionsBuilder builder_(_fbb);
+ builder_.add_num_bits(num_bits);
+ builder_.add_max(max);
+ builder_.add_min(min);
+ builder_.add_narrow_range(narrow_range);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5413,6 +5599,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const PowOptions *builtin_options_as_PowOptions() const {
return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr;
}
+ const ArgMinOptions *builtin_options_as_ArgMinOptions() const {
+ return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast<const ArgMinOptions *>(builtin_options()) : nullptr;
+ }
+ const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const {
+ return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast<const FakeQuantOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5668,6 +5860,14 @@ template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() c
return builtin_options_as_PowOptions();
}
+template<> inline const ArgMinOptions *Operator::builtin_options_as<ArgMinOptions>() const {
+ return builtin_options_as_ArgMinOptions();
+}
+
+template<> inline const FakeQuantOptions *Operator::builtin_options_as<FakeQuantOptions>() const {
+ return builtin_options_as_FakeQuantOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7333,6 +7533,32 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ArgMinOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = output_type(); _o->output_type = _e; };
+}
+
+inline flatbuffers::Offset<ArgMinOptions> ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateArgMinOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _output_type = _o->output_type;
+ return tflite::CreateArgMinOptions(
+ _fbb,
+ _output_type);
+}
+
inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new GreaterOptionsT();
UnPackTo(_o, _resolver);
@@ -7670,6 +7896,41 @@ inline flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferB
_fbb);
}
+inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FakeQuantOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = min(); _o->min = _e; };
+ { auto _e = max(); _o->max = _e; };
+ { auto _e = num_bits(); _o->num_bits = _e; };
+ { auto _e = narrow_range(); _o->narrow_range = _e; };
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> FakeQuantOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFakeQuantOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _min = _o->min;
+ auto _max = _o->max;
+ auto _num_bits = _o->num_bits;
+ auto _narrow_range = _o->narrow_range;
+ return tflite::CreateFakeQuantOptions(
+ _fbb,
+ _min,
+ _max,
+ _num_bits,
+ _narrow_range);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8083,6 +8344,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8325,6 +8594,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8555,6 +8832,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const PowOptionsT *>(value);
return CreatePowOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptionsT *>(value);
+ return CreateArgMinOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptionsT *>(value);
+ return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8785,6 +9070,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ value = new ArgMinOptionsT(*reinterpret_cast<ArgMinOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ value = new FakeQuantOptionsT(*reinterpret_cast<FakeQuantOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9072,6 +9365,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<ArgMinOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<FakeQuantOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 50237ed792..1093bd2cbe 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -678,6 +678,55 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ # There should be only 1 trainable variable tensor.
+ variables = tf.all_variables()
+ assert len(variables) == 1
+ sess.run(variables[0].assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
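
For reference, the function these tests exercise is plain PReLU: prelu(x) = max(x, 0) + alpha * min(x, 0), with alpha broadcasting over the shared axes (its shape has 1s there). A numpy sketch, not the TFLite kernel:

    import numpy as np

    def prelu(x, alpha):
      # alpha has shape 1 along each axis in `shared_axes`, matching
      # tf.keras.layers.PReLU.
      return np.maximum(x, 0.0) + alpha * np.minimum(x, 0.0)
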
# This function tests various TensorFlow functions that generate Const ops,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -2175,7 +2224,7 @@ def make_topk_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_arg_max_tests(zip_path):
+def make_arg_min_max_tests(zip_path):
"""Make a set of tests to do arg_max."""
test_parameters = [{
@@ -2183,6 +2232,7 @@ def make_arg_max_tests(zip_path):
"input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"axis_is_last_dim": [True, False],
+ "is_arg_max": [True],
}]
def build_graph(parameters):
@@ -2195,7 +2245,10 @@ def make_arg_max_tests(zip_path):
axis = len(parameters["input_shape"]) - 1
else:
axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0))
- out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ if parameters["is_arg_max"]:
+ out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ else:
+ out = tf.arg_min(input_value, axis, output_type=parameters["output_type"])
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
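
The merged generator relies on argmin and argmax being mirror images, so one test body serves both ops. A quick numpy check of that identity:

    import numpy as np

    x = np.random.randn(2, 3, 4).astype(np.float32)
    assert (np.argmin(x, axis=-1) == np.argmax(-x, axis=-1)).all()
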
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index c4e20312d8..58f6bb5382 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -53,7 +53,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- {R"(^\/mul.*int32)", "68808744"},
{R"(^\/div.*int32)", "68808744"},
{R"(^\/sub.*int32)", "68808744"},
@@ -97,11 +96,12 @@ std::map<string, string> kBrokenTests = {
{R"(^\/gather.*axis=1)", "76910444"},
// No support for arbitrary dimensions in ArgMax.
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ "77546240"},
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"},
};
// Allows test data to be unzipped into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 209dce56cb..2c469c0e75 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -212,7 +212,7 @@ cc_library(
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
"graph_transformations/quantize_weights.cc",
- "graph_transformations/read_fake_quant_min_max.cc",
+ "graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
"graph_transformations/remove_tensorflow_identity.cc",
@@ -245,6 +245,7 @@ cc_library(
"graph_transformations/resolve_constant_strided_slice.cc",
"graph_transformations/resolve_constant_transpose.cc",
"graph_transformations/resolve_constant_unary.cc",
+ "graph_transformations/resolve_fake_quant_args_from_vars.cc",
"graph_transformations/resolve_mean_attributes.cc",
"graph_transformations/resolve_multiply_by_zero.cc",
"graph_transformations/resolve_pad_attributes.cc",
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6be6b25f93..bf9a51a525 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -884,6 +884,9 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
if (src_op.num_bits) {
(*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
}
+ if (src_op.narrow_range) {
+ (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
+ }
}
void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
@@ -1135,6 +1138,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
+void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
+ argmin_op->set_op("ArgMin");
+ argmin_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *argmin_op->add_input() = src_op.inputs[0];
+ *argmin_op->add_input() = src_op.inputs[1];
+ (*argmin_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+ (*argmin_op->mutable_attr())["Tidx"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[1]));
+ (*argmin_op->mutable_attr())["output_type"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+}
+
void ConvertTransposeOperator(const Model& model,
const TransposeOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1964,6 +1983,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kArgMax) {
ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kArgMin) {
+ ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kTopK_V2) {
ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
tensorflow_graph);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index 2c7ffe4884..1688586733 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -159,6 +159,7 @@ bool DequantizeArray(const string& array_name,
new_array.GetOrCreateMinMax() = array->GetMinMax();
fakequant_op->minmax.reset(new MinMax);
*fakequant_op->minmax = array->GetMinMax();
+ fakequant_op->narrow_range = array->narrow_range;
if (must_insert_fakequant_before) {
for (const auto& op : model->operators) {
for (string& output : op->outputs) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 8cd1298bca..7cc9bb75d7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -159,7 +159,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
-DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
+DECLARE_GRAPH_TRANSFORMATION(ReadArrayMinmaxAndNarrowRangeFromFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary)
DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose)
DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
@@ -194,6 +194,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights)
+DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars)
class PropagateDefaultMinMax : public GraphTransformation {
public:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
index 30be4ac0aa..b90a156a0d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -74,14 +74,30 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
if (relu_neg_input_op == nullptr ||
- relu_neg_input_op->type != OperatorType::kNeg ||
- relu_neg_input_op->fused_activation_function !=
- FusedActivationFunctionType::kRelu ||
relu_neg_input_op->inputs.size() != 1) {
return false;
}
- if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+ const Operator* final_input_op;
+ if (relu_neg_input_op->type == OperatorType::kNeg &&
+ relu_neg_input_op->fused_activation_function ==
+ FusedActivationFunctionType::kRelu) {
+ // This detects a Neg op with fused Relu activation function.
+ final_input_op = relu_neg_input_op;
+ } else {
+ // This detects a Neg op followed by a separate Relu op.
+ const auto* neg_input_op =
+ GetOpWithOutput(*model, relu_neg_input_op->inputs[0]);
+ if (neg_input_op == nullptr || neg_input_op->inputs.size() != 1 ||
+ relu_neg_input_op->type != OperatorType::kRelu ||
+ relu_neg_input_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ return false;
+ }
+ final_input_op = neg_input_op;
+ }
+
+ if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
return false;
}
@@ -112,7 +128,6 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
-
return true;
}
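
The identity behind both matched patterns is prelu(x) = relu(x) - alpha * relu(-x); the Neg feeding the second Relu may carry a fused Relu activation or appear as a separate op. A numpy check (illustrative):

    import numpy as np

    x = np.random.randn(8).astype(np.float32)
    alpha = 0.2
    relu = lambda v: np.maximum(v, 0.0)
    assert np.allclose(relu(x) - alpha * relu(-x),
                       np.where(x >= 0, x, alpha * x))
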
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 45d9f73a1e..f684de08ab 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -85,15 +85,8 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
dequantized_input_minmax = input_minmax;
auto& input_qparams = input_array.GetOrCreateQuantizationParams();
input_array.data_type = input_array.final_data_type;
- if (input_array.data_type == ArrayDataType::kUint8) {
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
- &input_qparams);
- } else if (input_array.data_type == ArrayDataType::kInt16) {
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax,
- &input_qparams);
- } else {
- LOG(FATAL) << "unhandled data type";
- }
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ input_array, input_array.data_type, &input_qparams);
transformation->AddMessageF(
"Created %s"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 00ab7cbaa9..670bcf64e7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -100,6 +100,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
break;
}
+ case OperatorType::kArgMin: {
+ // Data type of the ArgMin op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmin_op = static_cast<ArgMinOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type;
+ break;
+ }
case OperatorType::kRange: {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 0f2592d05f..3ad6b0ec6f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -30,15 +30,9 @@ namespace {
bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
ArrayDataType new_data_type,
const MinMax* new_minmax) {
- // The code below assumes kInt16, see
- // GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>
- if (new_data_type != ArrayDataType::kInt16) {
- return false;
- }
-
- bool changed = false;
// Ensure the array ends up in the new type (if it hasn't yet been quantized).
- if ((array->final_data_type != new_data_type)) {
+ bool changed = false;
+ if (array->final_data_type != new_data_type) {
array->final_data_type = new_data_type;
changed = true;
}
@@ -72,12 +66,10 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
"Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
ArrayDataTypeName(new_data_type));
-
array_minmax.min = min;
array_minmax.max = max;
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- array_minmax, array->quantization_params.get());
-
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ *array, new_data_type, array->quantization_params.get());
// Directly change the type as the array was already quantized.
array->data_type = new_data_type;
changed = true;
@@ -95,6 +87,7 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
changed = true;
}
}
+
return changed;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 8eb0423283..4f95c57451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1404,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1696,7 +1697,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
+ break;
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
break;
case OperatorType::kUnsupported:
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
index d74cad9a62..44733391f5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
@@ -74,46 +74,54 @@ ArrayDataType GetQuantizedDataType(const Array& array,
}
}
-void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
- QuantizationParams* quantization_params) {
- switch (data_type) {
+template <ArrayDataType A>
+void ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ const Array& array, QuantizationParams* quantization_params) {
+ *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
+ array.minmax->min, array.minmax->max, array.narrow_range);
+}
+
+void ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ const Array& array, ArrayDataType quantized_data_type,
+ QuantizationParams* quantization_params) {
+ switch (quantized_data_type) {
case ArrayDataType::kInt8:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt8>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt8>(array, quantization_params);
break;
case ArrayDataType::kUint8:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint8>(array, quantization_params);
break;
case ArrayDataType::kInt16:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt16>(array, quantization_params);
break;
case ArrayDataType::kUint16:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint16>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint16>(array, quantization_params);
break;
case ArrayDataType::kInt32:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt32>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt32>(array, quantization_params);
break;
case ArrayDataType::kUint32:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint32>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint32>(array, quantization_params);
break;
case ArrayDataType::kInt64:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt64>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt64>(array, quantization_params);
break;
case ArrayDataType::kUint64:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint64>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint64>(array, quantization_params);
break;
case ArrayDataType::kFloat:
case ArrayDataType::kNone:
default:
LOG(FATAL) << "Unhandled final quantization type "
- << static_cast<int>(data_type);
+ << static_cast<int>(quantized_data_type);
}
}
@@ -121,8 +129,8 @@ namespace {
template <ArrayDataType A>
std::unique_ptr<GenericBuffer> QuantizeBuffer(
- const GenericBuffer& buffer,
- const QuantizationParams& quantization_params) {
+ const Array& array, const QuantizationParams& quantization_params) {
+ const GenericBuffer& buffer = *array.buffer;
const auto inverse_scale = 1. / quantization_params.scale;
CHECK(buffer.type == ArrayDataType::kFloat);
const auto& float_buffer =
@@ -140,8 +148,15 @@ std::unique_ptr<GenericBuffer> QuantizeBuffer(
} else {
scaled_val = quantization_params.zero_point + inverse_scale * src_val;
}
- quantized_buffer->data[i] =
- tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ // In addition to its effect on the choice of quantization params upstream
+ // of here, narrow_range also means nudge the min quantized value by +1,
+ // so e.g. uint8 values get constrained to [1, 255].
+ if (integer_val == std::numeric_limits<DataType<A>>::min() &&
+ array.narrow_range) {
+ integer_val++;
+ }
+ quantized_buffer->data[i] = integer_val;
}
return std::unique_ptr<GenericBuffer>(quantized_buffer);
}
@@ -155,7 +170,7 @@ void QuantizeArray(GraphTransformation* transformation, Model* model,
CHECK(!array.quantization_params);
array.GetOrCreateQuantizationParams() = quantization_params;
if (array.buffer) {
- array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ array.buffer = QuantizeBuffer<A>(array, quantization_params);
}
array.data_type = A;
array.final_data_type = A;
@@ -210,8 +225,8 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
} else {
// Work around cases where we are asking for this prior to the Quantize
// transformation having added the quantization_params.
- GetQuantizationParams(quantized_data_type, *array.minmax,
- &quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, quantized_data_type, &quantization_params);
transformation->AddMessageF(
"No quantization params - infering from data type %s with minmax "
"%g,%g as zero_point=%g, scale=%g",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
index 79a2ce7e50..cf093c6f17 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
@@ -38,21 +38,11 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
ArrayDataType GetQuantizedDataType(const Array& array,
ArrayDataType default_type);
-// Returns the quantization params for the array with the given data type and
-// minmax.
-void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
- QuantizationParams* quantization_params);
-
-// Returns the quantization params for the data type and minmax values.
-template <ArrayDataType A>
-void GetQuantizationParamsFromMinMax(const MinMax& minmax,
- QuantizationParams* quantization_params) {
- using Integer = DataType<A>;
- const double rmin = minmax.min;
- const double rmax = minmax.max;
- *quantization_params =
- ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
-}
+// Chooses the quantization params for a given array and a given target
+// quantized data type (which may not be the array's current data type).
+void ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ const Array& array, ArrayDataType quantized_data_type,
+ QuantizationParams* quantization_params);
// Quantizes an array by setting its data type and (if constant) quantizing
// all values in the array.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 58885b4950..5be2757479 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -212,13 +212,15 @@ bool ChooseQuantizationForOperatorInput(
if (op.type == OperatorType::kLstmCell) {
if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
*quantized_data_type = ArrayDataType::kInt16;
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
return true;
}
}
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
transformation->AddMessageF(
"For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) "
"with zero_point=%d, scale=%g",
@@ -358,12 +360,14 @@ bool ChooseQuantizationForOperatorOutput(
if (output_index == LstmCellOperator::STATE_OUTPUT ||
output_index == LstmCellOperator::ACTIV_TEMP) {
*quantized_data_type = ArrayDataType::kInt16;
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
return true;
}
}
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
transformation->AddMessageF(
"For output array %s with min=%g, max=%g"
", chose to quantize as %s with zero_point=%d"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
index 88ea0945e7..7a8515f6d1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
@@ -36,10 +36,8 @@ void GetQuantizationParamsFromArray(const Array& array,
const std::vector<float>& float_vals =
array.GetBuffer<ArrayDataType::kFloat>().data;
auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
- MinMax toco_minmax;
- toco_minmax.min = *minmax.first;
- toco_minmax.max = *minmax.second;
- GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params);
+ *params = tflite::ChooseQuantizationParams<uint8>(
+ *minmax.first, *minmax.second, array.narrow_range);
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
new file mode 100644
index 0000000000..5b41c49bfa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
+ const FakeQuantOperator& fq_op,
+ const string& array_name) {
+ bool changed = false;
+ auto& annotated_array = model->GetArray(array_name);
+ if (!annotated_array.minmax) {
+ const MinMax& minmax = *fq_op.minmax;
+ annotated_array.GetOrCreateMinMax() = minmax;
+ transformation->AddMessageF(
+ "Read min/max annotation for array %s: min=%g, max=%g", array_name,
+ minmax.min, minmax.max);
+ changed = true;
+ }
+ if (fq_op.narrow_range && !annotated_array.narrow_range) {
+ annotated_array.narrow_range = true;
+ transformation->AddMessageF("Read narrow_range annotation for array %s",
+ array_name);
+ changed = true;
+ }
+ return changed;
+}
+
+} // end namespace
+
+bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
+ std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fq_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (!fq_op->minmax) {
+ // The min/max args need to be resolved first by ResolveFakeQuantArgsFromVars.
+ return false;
+ }
+
+ // At this point, this FakeQuantOperator should have a MinMax
+ // attached to it, and should only have 1 input (it should not have
+ // 2nd and 3rd input arrays giving min and max anymore).
+ CHECK(fq_op->minmax);
+ CHECK_EQ(1, fq_op->inputs.size());
+
+ // Annotate both the input and the output array; a short-circuiting ||
+ // here would skip the output array whenever the input array changed.
+ bool changed = false;
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+ return changed;
+}
+
+} // namespace toco
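Like the other transformations in this directory, Run() returns whether it changed the model, and returns false to yield until its prerequisite (here, ResolveFakeQuantArgsFromVars) has fired. A rough Python sketch of the driver contract this implies; the real loop lives in RunGraphTransformations, and this version glosses over ops being erased mid-pass:

    def run_until_fixpoint(model, transformations):
        # Keep re-applying every transformation to every op until a full
        # pass makes no changes; a transformation whose prerequisites are
        # not met simply returns False and is retried on the next pass.
        while True:
            changed = False
            for transformation in transformations:
                for op_index in range(len(model.operators)):
                    changed |= transformation.run(model, op_index)
            if not changed:
                return model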
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
deleted file mode 100644
index bdcca5b7ca..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-namespace {
-
-bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model,
- const MinMax& minmax, const string& array_name) {
- auto& annotated_array = model->GetArray(array_name);
- if (annotated_array.minmax) {
- return false;
- }
- annotated_array.GetOrCreateMinMax() = minmax;
- transformation->AddMessageF(
- "Read min/max annotation for array %s: min=%g, max=%g", array_name,
- minmax.min, minmax.max);
- return true;
-}
-
-} // end namespace
-
-bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
- const auto fakequant_it = model->operators.begin() + op_index;
- auto* fakequant_base_op = fakequant_it->get();
- if (fakequant_base_op->type != OperatorType::kFakeQuant) {
- return false;
- }
- auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
-
- bool changed = false;
-
- if (!fakequant_op->minmax) {
- CHECK_EQ(fakequant_op->inputs.size(), 3);
- // We need to yield until the min and max parameters have been
- // resolved to constant arrays.
- for (int i = 1; i <= 2; i++) {
- if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) {
- return false;
- }
- }
-
- // Obtain the final min/max values
- const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
- const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
- CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
- CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
- fakequant_op->minmax.reset(new MinMax);
- MinMax& minmax = *fakequant_op->minmax;
- minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
- minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
- // We always want [min, max] to contain 0.
- if (minmax.min > 0 || minmax.max < 0) {
- LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range "
- << "[" << minmax.min << ", " << minmax.max
- << "] does not contain 0. "
- << "Proceeding by tweaking it to contain 0, which will result "
- "in poor accuracy.";
- }
- minmax.min = std::min(minmax.min, 0.);
- minmax.max = std::max(minmax.max, 0.);
-
- // We won't use the input arrays that provided these min and max
- // values, anymore. Delete them unless they are used by something
- // else.
- for (int i = 1; i <= 2; i++) {
- if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->EraseArray(fakequant_op->inputs[i]);
- }
- }
- fakequant_op->inputs.resize(1);
- changed = true;
- }
-
- // At this point, this FakeQuantOperator should have a MinMax
- // attached to it, and should only have 1 input (it should not have
- // 2nd and 3rd input arrays giving min and max anymore).
- CHECK(fakequant_op->minmax);
- CHECK_EQ(1, fakequant_op->inputs.size());
-
- const MinMax& minmax = *fakequant_op->minmax;
-
- // Record the MinMax info on the input and output arrays
- changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]);
- changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]);
-
- return changed;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index efb7bb2184..058f314b33 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -25,6 +25,37 @@ limitations under the License.
namespace toco {
+template <ArrayDataType A>
+void GetBoundsForQuantizedDataType(double* min, double* max) {
+ using limits = std::numeric_limits<DataType<A>>;
+ *min = limits::min();
+ *max = limits::max();
+}
+
+void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
+ double* min, double* max) {
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint8>(min, max);
+ case ArrayDataType::kInt8:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt8>(min, max);
+ case ArrayDataType::kUint16:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint16>(min, max);
+ case ArrayDataType::kInt16:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt16>(min, max);
+ case ArrayDataType::kUint32:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint32>(min, max);
+ case ArrayDataType::kInt32:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt32>(min, max);
+ case ArrayDataType::kUint64:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint64>(min, max);
+ case ArrayDataType::kInt64:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt64>(min, max);
+ default:
+ LOG(FATAL) << "unhandled quantized data type";
+ }
+}
+
bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const auto fakequant_it = model->operators.begin() + op_index;
const auto* fakequant_base_op = fakequant_it->get();
@@ -76,14 +107,21 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const int size = input_buffer.data.size();
output_buffer.data.resize(size);
QuantizationParams qparams;
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(*fakequant_op->minmax,
- &qparams);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ output_array, quantized_data_type, &qparams);
+ double quantized_min, quantized_max;
+ GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min,
+ &quantized_max);
+ if (fakequant_op->narrow_range) {
+ quantized_min++;
+ }
+
for (int i = 0; i < size; i++) {
const double src_val = input_buffer.data[i];
const double unclamped_quantized_val =
std::round(qparams.zero_point + src_val / qparams.scale);
- const double quantized_val =
- std::min(255., std::max(0., unclamped_quantized_val));
+ const double quantized_val = std::min(
+ quantized_max, std::max(quantized_min, unclamped_quantized_val));
const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
output_buffer.data[i] = dst_val;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc
new file mode 100644
index 0000000000..0dda1fd0b3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (fakequant_op->minmax) {
+ // Already resolved.
+ return false;
+ }
+
+ CHECK_EQ(fakequant_op->inputs.size(), 3);
+ // We need to yield until the min and max parameters have been
+ // resolved to constant arrays.
+ for (int i = 1; i <= 2; i++) {
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
+ return false;
+ }
+ }
+
+ // Obtain the final min/max values
+ const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
+ const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
+ CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
+ CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
+ fakequant_op->minmax.reset(new MinMax);
+ MinMax& minmax = *fakequant_op->minmax;
+ minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // We always want [min, max] to contain 0.
+ if (minmax.min > 0 || minmax.max < 0) {
+ LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range "
+ << "[" << minmax.min << ", " << minmax.max
+ << "] does not contain 0. "
+ << "Proceeding by tweaking it to contain 0, which will result "
+ "in poor accuracy.";
+ }
+ minmax.min = std::min(minmax.min, 0.);
+ minmax.max = std::max(minmax.max, 0.);
+
+ // We won't use the input arrays that provided these min and max
+ // values, anymore. Delete them unless they are used by something
+ // else.
+ for (int i = 1; i <= 2; i++) {
+ DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model);
+ }
+ fakequant_op->inputs.resize(1);
+ return true;
+}
+
+} // namespace toco
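One detail worth calling out: the [min, max] range is widened to contain 0 so that the real value 0.0 is exactly representable after quantization (zero padding and zero-initialized state depend on this). For example, a recorded range of [0.5, 2.0] becomes [0.0, 2.0], at some cost in resolution:

    def nudge_minmax_to_contain_zero(rmin, rmax):
        # Mirrors the std::min/std::max clamping above.
        return min(rmin, 0.0), max(rmax, 0.0)

    assert nudge_minmax_to_contain_zero(0.5, 2.0) == (0.0, 2.0)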
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5c32a39035..ab3762e7ea 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -755,6 +755,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
op->outputs.push_back(node.name());
// tf.fake_quant_with_min_max_args num_bits defaults to 8.
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
+ if (HasAttr(node, "narrow_range")) {
+ op->narrow_range = GetBoolAttr(node, "narrow_range");
+ }
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
@@ -774,6 +777,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars(
}
op->outputs.push_back(node.name());
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
+ if (HasAttr(node, "narrow_range")) {
+ op->narrow_range = GetBoolAttr(node, "narrow_range");
+ }
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
@@ -1230,10 +1236,11 @@ tensorflow::Status ConvertGatherOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertArgMaxOperator(
+template <typename Op, const char* op_name>
+tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "ArgMax");
+ CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1242,7 +1249,7 @@ tensorflow::Status ConvertArgMaxOperator(
: DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
- auto* op = new ArgMaxOperator;
+ auto* op = new Op;
op->output_data_type = ConvertDataType(output_type);
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1833,12 +1840,16 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+constexpr char kArgMax[] = "ArgMax";
+constexpr char kArgMin[] = "ArgMin";
+
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"ArgMax", ConvertArgMaxOperator},
+ {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
+ {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3a1d243f87..d06a30b638 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -140,6 +140,7 @@ enum class OperatorType : uint8 {
kEqual,
kNotEqual,
kPow,
+ kArgMin,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -790,6 +791,7 @@ struct FakeQuantOperator : Operator {
FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
std::unique_ptr<MinMax> minmax;
int num_bits = 8;
+ bool narrow_range = false;
};
// Element-wise division operator.
@@ -1528,6 +1530,17 @@ struct ArgMaxOperator : Operator {
ArrayDataType output_data_type = ArrayDataType::kInt64;
};
+// ArgMin operator. It returns the index of the minimum value along axis.
+//
+// Inputs:
+// inputs[0]: required: the input tensor
+//
+// TensorFlow equivalent: ArgMin
+struct ArgMinOperator : Operator {
+ ArgMinOperator() : Operator(OperatorType::kArgMin) {}
+ ArrayDataType output_data_type = ArrayDataType::kInt64;
+};
+
// ResizeBilinear operator. It resizes input images with bilinear interpolation.
// It does not support align_corners at the moment.
//
@@ -1842,6 +1855,40 @@ struct Array {
// If this is non-null, then these quantization parameters are to be used
// to assign a meaning as real numbers to the elements of this array.
std::unique_ptr<QuantizationParams> quantization_params;
+ // narrow_range is a detail of how toco handles FakeQuant operators with
+ // narrow_range, see
+ // https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars
+ //
+ // For more context about what that is useful for, see the big comment in
+ // graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+ //
+ // The narrow_range flag applies only to quantized arrays, and changes
+ // their quantization in the following way when it is set to 'true':
+ // 1. The computation of {zero_point, scale} from {min, max} needs to be
+ // amended so that the real min value will get quantized to
+ // (min_quantized_value + 1) instead of just (min_quantized_value).
+ // E.g. for uint8 quantization, the real min value should get quantized to
+ // the uint8 value 1, not 0.
+ // 2. Quantized values should get clamped to the interval
+ // [min_quantized_value + 1, max_quantized_value]. Equivalently, the
+ // min_quantized_value should get nudged to (min_quantized_value + 1).
+ // The reason why 1. does not imply 2. is that real values may not belong to
+ // the stated [min, max] interval. Concretely, weights recorded at the last
+ // learning step may not fall in the [min, max] interval recorded over
+ // previous learning steps, as the values evolve across learning steps.
+ //
+ // Rationale why this is directly a field on Array:
+ // - This can't be just a field on FakeQuantOperator, because
+ // FakeQuantOperators are gone (DropFakeQuant) before we get to using that
+ // information (Quantize). We need a place to store that bit in the interim.
+ // - This can't be in QuantizationParams because we need to record this
+ // ahead of quantization, and QuantizationParams are only created during
+ // quantization.
+ // - This could be in MinMax, but that would be an abuse of what MinMax is
+ // about, and would break existing code that assumes that a MinMax is just
+ // a min and a max. Unlike MinMax which is agnostic as to the quantized
+ // data type, narrow_range refers to values in the quantized data type.
+ bool narrow_range = false;
private:
std::unique_ptr<Shape> array_shape;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 1972246807..5ad307af14 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -336,17 +336,13 @@ void Export(
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
- const string fake_quant_operation_name = "FAKE_QUANT";
-
- if (error_summary.count(fake_quant_operation_name) != 0) {
- LOG(ERROR)
- << fake_quant_operation_name
- << " operation was not converted. If running quantized make sure you "
- "are passing --inference_type=QUANTIZED_UINT8 and values for "
- "--std_values and --mean_values.";
- // Remove the fake quant operation from the errors, since it shouldn't
- // be provided a custom implementation.
- error_summary.erase(fake_quant_operation_name);
+ for (const auto& op : model.operators) {
+ if (op->type == OperatorType::kFakeQuant) {
+ LOG(WARNING) << "FAKE_QUANT operation " << LogName(*op)
+ << " was not converted. If running quantized make sure you "
+ "are passing --inference_type=QUANTIZED_UINT8 and values "
+ "for --std_values and --mean_values.";
+ }
}
if (!allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7e55ae92bd..a791e60f91 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -282,25 +282,31 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
int GetVersion(const Operator& op) const override { return 1; }
};
-class FakeQuant : public CustomOperator<FakeQuantOperator> {
+class FakeQuant
+ : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
+ ::tflite::BuiltinOptions_FakeQuantOptions> {
public:
- using CustomOperator::CustomOperator;
- void WriteOptions(const TocoOperator& op,
- flexbuffers::Builder* fbb) const override {
- fbb->Float("min", op.minmax->min);
- fbb->Float("max", op.minmax->max);
- fbb->Int("num_bits", op.num_bits);
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateFakeQuantOptions(
+ *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
}
- void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
auto* minmax = new MinMax;
- minmax->min = m["min"].AsFloat();
- minmax->max = m["max"].AsFloat();
+ minmax->min = options.min();
+ minmax->max = options.max();
op->minmax.reset(minmax);
- const auto& num_bits = m["num_bits"];
- op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8;
+ op->num_bits = options.num_bits();
+ op->narrow_range = options.narrow_range();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& fq_op = static_cast<const FakeQuantOperator&>(op);
+ return fq_op.narrow_range ? 2 : 1;
+ }
};
class FullyConnected
@@ -885,6 +891,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
+ ::tflite::BuiltinOptions_ArgMinOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateArgMinOptions(
+ *builder, DataType::Serialize(op.output_data_type));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->output_data_type = DataType::Deserialize(options.output_type());
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TransposeConv
: public BuiltinOperator<TransposeConvOperator,
::tflite::TransposeConvOptions,
@@ -1175,6 +1200,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
ops.emplace_back(
+ new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin));
+ ops.emplace_back(
new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS,
OperatorType::kExpandDims));
@@ -1184,11 +1211,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kSparseToDense));
ops.emplace_back(
new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
+ ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT,
+ OperatorType::kFakeQuant));
// Custom Operators.
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
- ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED",
OperatorType::kUnsupported));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 8b6808d3c7..ff2d35b1f5 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -416,6 +416,13 @@ TEST_F(OperatorTest, BuiltinArgMax) {
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
+TEST_F(OperatorTest, BuiltinArgMin) {
+ ArgMinOperator op;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ARG_MIN", OperatorType::kArgMin), op);
+ EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
+}
+
TEST_F(OperatorTest, BuiltinTransposeConv) {
TransposeConvOperator op;
op.stride_width = 123;
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index fc1636831b..a4dc1bbe93 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -105,7 +105,8 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new IdentifyRelu1);
transformations->Add(new IdentifyPRelu);
transformations->Add(new RemoveTrivialBinaryOperator);
- transformations->Add(new ReadFakeQuantMinMax);
+ transformations->Add(new ResolveFakeQuantArgsFromVars);
+ transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
transformations->Add(new ResolveSpaceToBatchNDAttributes);
transformations->Add(new ResolveBatchToSpaceNDAttributes);
transformations->Add(new ResolvePadAttributes);
@@ -273,13 +274,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new toco::MergeLstmCellInputs);
}
}
- if (toco_flags.quantize_weights()) {
- transformations.Add(new QuantizeWeights);
- }
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+ if (toco_flags.quantize_weights()) {
+ // Run the quantize weights transformation after batchnorms have been
+ // folded into the weights.
+ RunGraphTransformations(model, "quantize weights transformation",
+ {new QuantizeWeights});
+ }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 01113506d0..4ec74e351f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -387,6 +387,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Mean)
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
+ HANDLE_OPERATORTYPENAME_CASE(ArgMin)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)
@@ -1265,8 +1266,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name,
auto* copy_op = new TensorFlowReshapeOperator;
copy_op->inputs = {
source_array_name,
- CreateInt32Array(model, target_array_name + "_copy_shape", shape)};
+ CreateInt32Array(
+ model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
+ shape)};
copy_op->outputs = {target_array_name};
+ if (target_array.has_shape()) {
+ copy_op->shape = target_array.shape().dims();
+ }
model->operators.emplace_back(copy_op);
}
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index a3df37358f..d070018e83 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -14,6 +14,7 @@ py_binary(
srcs = ["visualize.py"],
data = [
"//tensorflow/contrib/lite/schema:schema.fbs",
+ "//tensorflow/python:platform",
"@flatbuffers//:flatc",
],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
index 93769305bd..f1e257ad10 100644
--- a/tensorflow/contrib/lite/tools/benchmark/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -115,7 +115,7 @@ E.g. for running the benchmark on big cores on Pixel 2 with a single thread one
can use the following command:
```
-adb shell tasket f0 /data/local/tmp/benchmark_model \
+adb shell taskset f0 /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
--input_layer="input" \
--input_layer_shape="1,224,224,3" \
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
index 08648bcfe2..19b9a9c7ba 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
@@ -98,10 +98,13 @@ void BenchmarkModel::LogFlags() {
<< "]";
}
+void BenchmarkModel::PrepareInputsAndOutputs() {}
+
Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
Stat<int64_t> run_stats;
TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
for (int run = 0; run < num_times; run++) {
+ PrepareInputsAndOutputs();
listeners_.OnSingleRunStart(run_type);
int64_t start_us = profiling::time::NowMicros();
RunImpl();
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
index 942e21f67a..3c7063b2d4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
@@ -150,6 +150,7 @@ class BenchmarkModel {
virtual std::vector<Flag> GetFlags();
virtual uint64_t ComputeInputBytes() = 0;
virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
+ virtual void PrepareInputsAndOutputs();
virtual void RunImpl() = 0;
BenchmarkParams params_;
BenchmarkListeners listeners_;
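PrepareInputsAndOutputs() is a classic template-method hook: Run() calls it at a fixed point in each iteration, before the timer starts, so subclasses can refresh inputs without polluting the measured latency. A sketch of the pattern in Python (illustrative class, not the benchmark tool's API):

    import time

    class BenchmarkModel:
        def prepare_inputs_and_outputs(self):
            pass  # default no-op, as in the C++ base class

        def run_impl(self):
            raise NotImplementedError

        def run(self, num_times):
            # Preparation happens inside the loop but outside the timed
            # region, so per-iteration input setup is never measured.
            stats = []
            for _ in range(num_times):
                self.prepare_inputs_and_outputs()
                start = time.perf_counter()
                self.run_impl()
                stats.append(time.perf_counter() - start)
            return stats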
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index f571dd59da..e07f899e4d 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -28,11 +28,24 @@ import json
import os
import sys
+from tensorflow.python.platform import resource_loader
+
# Schema to use for flatbuffers
_SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs"
-# Where the binary will be once built in for the flatc converter
-_BINARY = "third_party/flatbuffers/flatc"
+# TODO(angerson): fix later when rules are simplified.
+_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs")
+_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc")
+# Account for different package positioning internal vs. external.
+if not os.path.exists(_BINARY):
+ _BINARY = resource_loader.get_path_to_datafile(
+ "../../../../../flatbuffers/flatc")
+
+if not os.path.exists(_SCHEMA):
+ raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA)
+if not os.path.exists(_BINARY):
+ raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY)
+
# A CSS description for making the visualizer
_CSS = """
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 66cb493e5c..21cd34f73f 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -31,6 +31,7 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:confusion_matrix",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:histogram_ops",
"//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 5effea3596..88798d61b7 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -63,6 +63,7 @@ See the @{$python/contrib.metrics} guide.
@@aggregate_metrics
@@aggregate_metric_map
@@confusion_matrix
+@@f1_score
@@set_difference
@@set_intersection
@@set_size
diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py
index 26aba1cc51..e553612269 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification.py
@@ -22,6 +22,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribute as distribute_lib
# TODO(nsilberman): move into metrics/python/ops/
@@ -62,3 +65,121 @@ def accuracy(predictions, labels, weights=None, name=None):
return math_ops.div(math_ops.reduce_sum(is_correct),
math_ops.reduce_sum(num_values))
return math_ops.reduce_mean(is_correct)
+
+
+def f1_score(labels, predictions, weights=None, num_thresholds=200,
+ metrics_collections=None, updates_collections=None, name=None):
+ """Computes the approximately best F1-score across different thresholds.
+
+ The f1_score function applies a range of thresholds to the predictions to
+ convert them from [0, 1] to bool. Precision and recall are computed by
+ comparing them to the labels. The F1-Score is then defined as
+ 2 * precision * recall / (precision + recall). The best one across the
+ thresholds is returned.
+
+ Disclaimer: In practice it may be desirable to choose the best threshold on
+ the validation set and evaluate the F1 score with this threshold on a
+ separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).
+
+ This function internally creates four local variables, `true_positives`,
+ `true_negatives`, `false_positives` and `false_negatives` that are used to
+ compute the pairs of recall and precision values for a linearly spaced set of
+ thresholds from which the best f1-score is derived.
+
+ This value is ultimately returned as `f1-score`, an idempotent operation that
+ computes the F1-score (computed using the aforementioned variables). The
+ `num_thresholds` variable controls the degree of discretization with larger
+ numbers of thresholds more closely approximating the true best F1-score.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the F1-score.
+
+ Example usage with a custom estimator:
+ def model_fn(features, labels, mode):
+ predictions = make_predictions(features)
+ loss = make_loss(predictions, labels)
+ train_op = tf.contrib.training.create_train_op(
+ total_loss=loss,
+ optimizer='Adam')
+ eval_metric_ops = {'f1': f1_score(labels, predictions)}
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=export_outputs)
+ estimator = tf.estimator.Estimator(model_fn=model_fn)
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+ `bool`.
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ num_thresholds: The number of thresholds to use when discretizing the
+ precision-recall curve.
+ metrics_collections: An optional list of collections that `f1_score` should
+ be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ f1_score: A scalar `Tensor` representing the current best f1-score across
+ different thresholds.
+ update_op: An operation that increments the `true_positives`,
+ `true_negatives`, `false_positives` and `false_negatives` variables
+ appropriately and whose value matches the `f1_score`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'f1', (labels, predictions, weights)):
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions, labels=labels, weights=weights)
+ # To account for floating point imprecisions / avoid division by zero.
+ epsilon = 1e-7
+ thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+ for i in range(num_thresholds - 2)]
+ thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]
+
+ # Confusion matrix.
+ values, update_ops = metrics_impl._confusion_matrix_at_thresholds( # pylint: disable=protected-access
+ labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))
+
+ # Compute precision and recall at various thresholds.
+ def compute_best_f1_score(tp, fp, fn, name):
+ precision_at_t = math_ops.div(tp, epsilon + tp + fp,
+ name='precision_' + name)
+ recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
+ # Compute F1 score.
+ f1_at_thresholds = (
+ 2.0 * precision_at_t * recall_at_t /
+ (precision_at_t + recall_at_t + epsilon))
+ return math_ops.reduce_max(f1_at_thresholds)
+
+ def f1_across_towers(_, values):
+ best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
+ fn=values['fn'], name='value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, best_f1)
+ return best_f1
+
+ best_f1 = distribute_lib.get_tower_context().merge_call(
+ f1_across_towers, values)
+
+ update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
+ fn=update_ops['fn'], name='update')
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return best_f1, update_op
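For intuition, the streamed metric above is equivalent, up to threshold placement (the real code pads the threshold list with -epsilon and 1 + epsilon) and the epsilon terms, to this brute-force NumPy computation over one batch (an illustrative helper, not part of the API):

    import numpy as np

    def best_f1(labels, predictions, num_thresholds=200, eps=1e-7):
        # Sweep thresholds, form confusion counts, and keep the best F1.
        best = 0.0
        for t in np.linspace(0.0, 1.0, num_thresholds):
            pred_pos = predictions >= t
            tp = np.sum(pred_pos & (labels == 1))
            fp = np.sum(pred_pos & (labels == 0))
            fn = np.sum(~pred_pos & (labels == 1))
            precision = tp / (eps + tp + fp)
            recall = tp / (eps + tp + fn)
            best = max(best, 2 * precision * recall /
                       (precision + recall + eps))
        return best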
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index fa0f12d029..3d0b81c1be 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -18,9 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.metrics.python.metrics import classification
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
+class F1ScoreTest(test.TestCase):
+
+ def setUp(self):
+ super(F1ScoreTest, self).setUp()
+ np.random.seed(1)
+
+ def testVars(self):
+ classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3)
+ expected = {'f1/true_positives:0', 'f1/false_positives:0',
+ 'f1/false_negatives:0'}
+ self.assertEquals(
+ expected, set(v.name for v in variables.local_variables()))
+ self.assertEquals(
+ set(expected),
+ set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ f1, _ = classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3,
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [f1])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, f1_op = classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3,
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [f1_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=2, dtype=dtypes.int64, seed=2)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run([f1_op])
+
+ # Then verify idempotency.
+ initial_f1 = f1.eval()
+ for _ in range(10):
+ self.assertAllClose(initial_f1, f1.eval())
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+ labels = constant_op.constant(inputs)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertEqual(1, f1.eval())
+
+ def testSomeCorrect(self):
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+ # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
+ # score of 2 * 0.5 * 1 / (1 + 0.5).
+ self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(10000, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+ labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
+ # score of 2 * 0.5 * 1 / (1 + 0.5).
+ self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
+
+ def testWeights1d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0], [1]], shape=(2, 1), dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, weights,
+ num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+ def testWeights2d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, weights,
+ num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+ def testZeroLabelsPredictions(self):
+ with self.test_session() as sess:
+ predictions = array_ops.zeros([4], dtype=dtypes.float32)
+ labels = array_ops.zeros([4])
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(0.0, f1.eval(), places=5)
+
+ def testWithMultipleUpdates(self):
+ num_samples = 1000
+ batch_size = 10
+ num_batches = int(num_samples / batch_size)
+
+ # Create the labels and data.
+ labels = np.random.randint(0, 2, size=(num_samples, 1))
+ noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+ predictions = 0.4 + 0.2 * labels + noise
+ predictions[predictions > 1] = 1
+ predictions[predictions < 0] = 0
+ thresholds = [-0.01, 0.5, 1.01]
+
+ expected_max_f1 = -1.0
+ for threshold in thresholds:
+ tp = 0
+ fp = 0
+ fn = 0
+ tn = 0
+ for i in range(num_samples):
+ if predictions[i] >= threshold:
+ if labels[i] == 1:
+ tp += 1
+ else:
+ fp += 1
+ else:
+ if labels[i] == 1:
+ fn += 1
+ else:
+ tn += 1
+ epsilon = 1e-7
+ expected_prec = tp / (epsilon + tp + fp)
+ expected_rec = tp / (epsilon + tp + fn)
+ expected_f1 = (2 * expected_prec * expected_rec /
+ (epsilon + expected_prec + expected_rec))
+ if expected_f1 > expected_max_f1:
+ expected_max_f1 = expected_f1
+
+ labels = labels.astype(np.float32)
+ predictions = predictions.astype(np.float32)
+ tf_predictions, tf_labels = (dataset_ops.Dataset
+ .from_tensor_slices((predictions, labels))
+ .repeat()
+ .batch(batch_size)
+ .make_one_shot_iterator()
+ .get_next())
+ f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
+ num_thresholds=3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in range(num_batches):
+ sess.run([f1_op])
+      # Since this is only approximate, we can't expect a six-digit match,
+      # although accuracy should improve with a higher number of samples and
+      # thresholds.
+ self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index ef34f7bf7b..93050a3ae3 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -77,7 +77,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
If gradients clipping is applied, one can call
`optimizer.compute_gradients()` and `optimizer.apply_gradients()`
- seperately.
+ separately.
Notice the following way of using LossScaleOptimizer is not intended. Always
use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py
new file mode 100644
index 0000000000..bd7096d9ce
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py
@@ -0,0 +1,163 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Inter-process communication using MPI."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _load_library(name, op_list=None):
+ """Loads a .so file containing the specified operators.
+
+ Args:
+ name: The name of the .so file to load.
+ op_list: A list of names of operators that the library should have. If None
+ then the .so file's contents will not be verified.
+
+ Raises:
+ NameError if one of the required ops is missing.
+ """
+ try:
+ filename = resource_loader.get_path_to_datafile(name)
+ library = load_library.load_op_library(filename)
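+    # Note: the for/else below raises only if the scan over library.OP_LIST.op
+    # completes without finding (and breaking on) expected_op.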
+ for expected_op in (op_list or []):
+ for lib_op in library.OP_LIST.op:
+ if lib_op.name == expected_op:
+ break
+ else:
+ raise NameError('Could not find operator %s in dynamic library %s' %
+ (expected_op, name))
+ return library
+ except errors.NotFoundError:
+ logging.warning('%s file could not be loaded.', name)
+
+
+MPI_LIB = _load_library(
+ 'mpi_collectives.so',
+ ['MPISize', 'MPIRank', 'MPILocalRank', 'MPIAllgather', 'MPIAllreduce'])
+
+
+def size(name=None):
+ """An op which returns the number of MPI processes.
+
+ This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
+ size of the global communicator.
+
+ Returns:
+ An integer scalar containing the number of MPI processes.
+ """
+ return MPI_LIB.mpi_size(name=name)
+
+
+ops.NotDifferentiable('MPISize')
+
+
+def rank(name=None):
+ """An op which returns the MPI rank of the calling process.
+
+ This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
+ rank of the current process in the global communicator.
+
+ Returns:
+ An integer scalar with the MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_rank(name=name)
+
+
+ops.NotDifferentiable('MPIRank')
+
+
+def init(name=None):
+ """An op which initializes MPI on the device on which it is run.
+
+ All future MPI ops must be run on the same device that the `init` op was run
+ on.
+ """
+ return MPI_LIB.mpi_init(name=name)
+
+
+ops.NotDifferentiable('MPIInit')
+
+
+def local_rank(name=None):
+ """An op which returns the local MPI rank of the calling process, within the
+ node that it is running on. For example, if there are seven processes running
+ on a node, their local ranks will be zero through six, inclusive.
+
+ This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
+ which only includes processes on the same node.
+
+ Returns:
+ An integer scalar with the local MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_local_rank(name=name)
+
+
+ops.NotDifferentiable('MPILocalRank')
+
+
+def _allreduce(tensor, name=None):
+ """An op which sums an input tensor over all the MPI processes.
+
+ The reduction operation is keyed by the name of the op. The tensor type and
+ shape must be the same on all MPI processes for a given name. The reduction
+ will not start until all processes are ready to send and receive the tensor.
+
+ Returns:
+ A tensor of the same shape and type as `tensor`, summed across all
+ processes.
+ """
+ return MPI_LIB.mpi_allreduce(tensor, name=name)
+
+
+ops.NotDifferentiable('MPIAllreduce')
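+
+# A minimal usage sketch for _allreduce (illustrative only; the reduction is
+# keyed by the op name, so every rank must create the op with the same name,
+# type, and shape, and init() must already have run on this device):
+#
+#   total = _allreduce(tf.constant([1.0, 2.0]))
+#   # On every rank, `total` evaluates to the elementwise sum across ranks.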
+
+
+def allgather(tensor, name=None):
+ """An op which concatenates the input tensor with the same input tensor on
+ all other MPI processes.
+
+ The concatenation is done on the first dimension, so the input tensors on the
+ different processes must have the same rank and shape, except for the first
+ dimension, which is allowed to be different.
+
+ Returns:
+ A tensor of the same type as `tensor`, concatenated on dimension zero
+ across all processes. The shape is identical to the input shape, except for
+ the first dimension, which may be greater and is the sum of all first
+ dimensions of the tensors in different MPI processes.
+ """
+  # Specify that the first allgather is to collect the tensor gather sizes,
+  # indicated by passing in a scalar (0-D tensor) of value 0.
+ sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const')
+ my_size = tf.slice(
+ tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice')
+ if name is None:
+ name = 'allgather'
+ sizing_name = '{}_sizing'.format(name)
+ sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name)
+ return MPI_LIB.mpi_allgather(tensor, sizes, name=name)
+
+
+ops.NotDifferentiable('MPIAllgather')
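+
+
+# A minimal usage sketch for allgather (illustrative only; assumes init() has
+# run on this process's device and that per-rank tensors differ only in their
+# first dimension):
+#
+#   local = tf.ones([rank() + 1, 3])
+#   gathered = allgather(local)
+#   # With 3 MPI processes, `gathered` evaluates with shape [1 + 2 + 3, 3].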
diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc
new file mode 100644
index 0000000000..d93233eb21
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+extern template MPI_Datatype MPIType<float>();
+extern template MPI_Datatype MPIType<int>();
+extern template MPI_Datatype MPIType<long long>();
+extern template DataType TensorFlowDataType<float>();
+extern template DataType TensorFlowDataType<int>();
+extern template DataType TensorFlowDataType<long long>();
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Copy data on a CPU using a straightforward memcpy.
+template <>
+void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
+  std::memcpy(dst, src, size);
+}
+
+// Accumulate values on a CPU.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<CPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ for (unsigned int i = 0; i < size; i++) { \
+ dst[i] += src[i]; \
+ } \
+  }
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc
new file mode 100644
index 0000000000..2f3eef366a
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <>
+MPI_Datatype MPIType<float>() {
+  return MPI_FLOAT;
+}
+template <>
+MPI_Datatype MPIType<int>() {
+  return MPI_INT;
+}
+template <>
+MPI_Datatype MPIType<long long>() {
+  return MPI_LONG_LONG;
+}
+
+template <>
+DataType TensorFlowDataType<float>() {
+  return DT_FLOAT;
+}
+template <>
+DataType TensorFlowDataType<int>() {
+  return DT_INT32;
+}
+template <>
+DataType TensorFlowDataType<long long>() {
+  return DT_INT64;
+}
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Synchronously copy data on the GPU, using a stream distinct from both the
+// default stream and TensorFlow's compute streams, to avoid synchronizing on
+// operations unrelated to the allreduce.
+template <>
+void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+  auto stream = CudaStreamForMPI();
+  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+  cudaStreamSynchronize(stream);
+}
+
+// Elementwise accumulation kernel for GPU, written as a grid-stride loop so
+// that any launch configuration covers all N elements.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] += in[i];
+ }
+}
+
+// Synchronously accumulate tensors on the GPU, again using a stream distinct
+// from both the default stream and TensorFlow's compute streams, to avoid
+// synchronizing on operations unrelated to the allreduce.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ auto stream = CudaStreamForMPI(); \
+ elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size); \
+ cudaStreamSynchronize(stream); \
+  }
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h
new file mode 100644
index 0000000000..cae57ce60e
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.h
@@ -0,0 +1,327 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_MPI_H_
+#define TENSORFLOW_CONTRIB_MPI_H_
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#if GOOGLE_CUDA
+#include "cuda_runtime.h"
+#endif
+
+// Needed to avoid header issues with C++-supporting MPI implementations
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+
+#define TAG_TENSOR 12
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+// Convert from templated types to values we can pass to MPI.
+template <typename T>
+MPI_Datatype MPIType();
+
+// Convert from templated types to TensorFlow data types.
+template <typename T>
+DataType TensorFlowDataType();
+
+#define MPI_REQUIRES_OK(MPI_STATUS) \
+ if ((MPI_STATUS) != MPI_SUCCESS) { \
+ return errors::Unknown("MPI operation failed unexpectedly."); \
+ }
+
+// Copy data from one tensor to another tensor.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device>
+void CopyTensorData(void* destination, void* source, size_t size);
+
+// Add a tensor into another tensor, accumulating in place.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device, typename T>
+void AccumulateTensorData(T* destination, T* source, size_t size);
+
+// We need to get the right stream for doing CUDA memory transfers and
+// operations, which is possibly different from the standard TensorFlow stream.
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI();
+#endif
+
+/* Perform a ring allreduce on the data. Allocate the necessary output tensor
+ * and store it in the output parameter.
+ *
+ * Assumes that all MPI processes are doing an allreduce of the same tensor,
+ * with the same dimensions.
+ *
+ * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
+ * allreduce, the nodes involved are arranged in a ring:
+ *
+ * .--0--.
+ * / \
+ * 3 1
+ * \ /
+ * *--2--*
+ *
+ * Each node always sends to the next clockwise node in the ring, and receives
+ * from the previous one.
+ *
+ * The allreduce is done in two parts: a scatter-reduce and an allgather. In
+ * the scatter reduce, a reduction is done, so that each node ends up with a
+ * chunk of the final output tensor which has contributions from all other
+ * nodes. In the allgather, those chunks are distributed among all the nodes,
+ * so that all nodes have the entire output tensor.
+ *
+ * Both of these operations are done by dividing the input tensor into N
+ * evenly sized chunks (where N is the number of nodes in the ring).
+ *
+ * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
+ * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it into
+ * its existing data for that chunk. For example, in the first iteration with
+ * the ring depicted above, you will have the following transfers:
+ *
+ * Segment 0: Node 0 --> Node 1
+ * Segment 1: Node 1 --> Node 2
+ * Segment 2: Node 2 --> Node 3
+ * Segment 3: Node 3 --> Node 0
+ *
+ * In the second iteration, you'll have the following transfers:
+ *
+ * Segment 0: Node 1 --> Node 2
+ * Segment 1: Node 2 --> Node 3
+ * Segment 2: Node 3 --> Node 0
+ * Segment 3: Node 0 --> Node 1
+ *
+ * After this iteration, Node 2 has three of the four contributions to
+ * Segment 0.
+ * The last iteration has the following transfers:
+ *
+ * Segment 0: Node 2 --> Node 3
+ * Segment 1: Node 3 --> Node 0
+ * Segment 2: Node 0 --> Node 1
+ * Segment 3: Node 1 --> Node 2
+ *
+ * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
+ * has the fully accumulated Segment 1; and so on. The scatter-reduce is
+ * complete.
+ *
+ * Next, the allgather distributes these fully accumulated chunks across all
+ * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
+ * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
+ * For example, at the first iteration, the following transfers will occur:
+ *
+ * Segment 0: Node 3 --> Node 0
+ * Segment 1: Node 0 --> Node 1
+ * Segment 2: Node 1 --> Node 2
+ * Segment 3: Node 2 --> Node 3
+ *
+ * After the first iteration, Node 0 will have a fully accumulated Segment 0
+ * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
+ * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
+ * After this has continued for N - 1 iterations, all nodes will have the
+ * fully accumulated tensor.
+ *
+ * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
+ * allgather. Each send will contain K / N bytes, if there are K bytes in the
+ * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N
+ * bytes of data, and the performance of the allreduce (assuming no latency in
+ * connections) is constrained by the slowest interconnect between the nodes.
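+ * As a concrete check of that formula (illustrative numbers): with N = 4
+ * nodes and K = 400 MB per node, each node sends and receives
+ * 2 * 400 * (4 - 1) / 4 = 600 MB over the course of the allreduce.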
+ *
+ */
+template <typename Device, typename T>
+Status RingAllreduce(OpKernelContext* context, const Tensor* input,
+ Tensor* temp, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ T* buffer = (T*)output->tensor_data().data();
+
+ CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
+ output->tensor_data().size());
+
+ // Calculate segment sizes and segment ends
+ const size_t elements_to_reduce = input->NumElements();
+ const size_t segment_size = elements_to_reduce / n;
+ std::vector<size_t> segment_sizes(n, segment_size);
+
+ const size_t residual = elements_to_reduce % n;
+ for (size_t i = 0; i < residual; ++i) {
+ segment_sizes[i]++;
+ }
+
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (size_t i = 1; i < segment_starts.size(); ++i) {
+ segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
+ }
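+  // For example, with 10 elements and n = 4, segment_sizes is {3, 3, 2, 2}
+  // and segment_starts is {0, 3, 6, 8}.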
+
+ assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);
+
+ T* segment_recv = (T*)temp->tensor_data().data();
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ MPI_Status recv_status;
+ MPI_Request recv_req;
+
+  // Now start the ring scatter-reduce. At every step, each rank sends to and
+  // receives from its neighbors, with segment indices wrapping around, and
+  // reduces locally. At the i'th iteration, rank r sends segment (r-i) and
+  // receives segment (r-i-1).
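+  // For example, with n = 4 at iteration i = 0, rank 2 sends segment 2 and
+  // receives segment 1; at i = 1, it sends segment 1 and receives segment 0.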
+ for (int i = 0; i < n - 1; i++) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
+ MPIType<T>(), recv_from, TAG_TENSOR,
+ MPI_COMM_WORLD, &recv_req));
+
+ MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
+ MPIType<T>(), send_to, TAG_TENSOR,
+ MPI_COMM_WORLD));
+
+ T* segment_update = &(buffer[segment_starts[recv_seg_id]]);
+
+ // Wait for recv to complete before reduction
+ MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));
+
+ const size_t recv_seg_size = segment_sizes[recv_seg_id];
+ AccumulateTensorData<Device, T>(segment_update, segment_recv,
+ recv_seg_size);
+ }
+
+  // Now start the pipelined ring allgather. At every step, each rank sends
+  // to and receives from its neighbors, with segment indices wrapping
+  // around. At the i'th iteration, rank r sends segment (r-i+1) and receives
+  // segment (r-i).
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i + 1) + n) % n;
+ const size_t recv_seg_id = ((r - i) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i+1)
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ // Segment to recv - at every iteration we receive segment (r-i)
+ T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+// Perform a ring allgather on a Tensor. Other ranks may allgather with a
+// tensor which differs in the first dimension only; all other dimensions must
+// be the same.
+//
+// For more information on the ring allgather, read the documentation for the
+// ring allreduce, which includes a ring allgather.
+template <typename Device, typename T>
+Status RingAllgather(OpKernelContext* context, const Tensor* input,
+ const std::vector<size_t>& sizes, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ assert(sizes.size() == n);
+ assert(input->dim_size(0) == sizes[r]);
+
+  // Compute the number of elements in every "row". We can't precompute the
+  // number of elements in each chunk, because the chunks have variable
+  // length.
+ size_t elements_per_row = 1;
+ for (int i = 1; i < input->shape().dims(); i++) {
+ elements_per_row *= input->dim_size(i);
+ }
+
+ // Copy data from input tensor to correct place in output tensor.
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (int i = 1; i < n; i++) {
+ segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
+ }
+ size_t offset = segment_starts[r];
+
+ // Copy data to the right offset for this rank.
+ T* buffer = (T*)output->tensor_data().data();
+ CopyTensorData<Device>((void*)(buffer + offset),
+ (void*)input->tensor_data().data(),
+ elements_per_row * sizes[r] * sizeof(T));
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+  // Perform a ring allgather. At every step, each rank sends to and receives
+  // from its neighbors, with segment indices wrapping around. At the i'th
+  // iteration, rank r sends segment (r-i) and receives segment (r-1-i).
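+  // For example, with n = 4 at iteration i = 0, rank 2 sends segment 2 and
+  // receives segment 1.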
+ MPI_Status recv_status;
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i)
+ size_t offset_send = segment_starts[send_seg_id];
+ size_t rows_send = sizes[send_seg_id];
+ T* segment_send = &(buffer[offset_send]);
+
+ // Segment to recv - at every iteration we receive segment (r-1-i)
+ size_t offset_recv = segment_starts[recv_seg_id];
+ size_t rows_recv = sizes[recv_seg_id];
+ T* segment_recv = &(buffer[offset_recv]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
+
+#endif // TENSORFLOW_CONTRIB_MPI_H_
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 08d45ed73f..628a735e72 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -214,7 +214,7 @@ class AddSignTest(test.TestCase):
# Run 7 steps of AddSign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
- for t in range(1, 4):
+ for t in range(1, 8):
if t < 5:
update.run()
else:
@@ -222,7 +222,7 @@ class AddSignTest(test.TestCase):
var0_np, m0 = addsign_update_numpy(
var0_np,
- grads0_np,
+ grads0_np if t < 5 else -grads0_np,
m0,
learning_rate,
alpha=alpha,
@@ -232,7 +232,7 @@ class AddSignTest(test.TestCase):
)
var1_np, m1 = addsign_update_numpy(
var1_np,
- grads1_np,
+ grads1_np if t < 5 else -grads1_np,
m1,
learning_rate,
alpha=alpha,
diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py
index 928c453517..cae952d8f5 100644
--- a/tensorflow/contrib/opt/python/training/ggt.py
+++ b/tensorflow/contrib/opt/python/training/ggt.py
@@ -33,7 +33,7 @@ class GGTOptimizer(optimizer_v2.OptimizerV2):
GGT has an advantage over sgd and adam on large models with poor conditioning,
for example language models and CNNs,
- see [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)).
+ see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf).
"""
def __init__(self,
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 5214082dd6..0bcf5d230a 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase):
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of powersign
+ # Run 7 steps of powersign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 3e9b1a0b8d..d45622174f 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -19,9 +19,7 @@ py_library(
py_library(
name = "proto_pip",
- data = [
- "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
- ] + if_static(
+ data = if_static(
[],
otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
),
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
index a380a131f8..3c6fde23d2 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -4,45 +4,18 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-# Much of the work in this BUILD file actually happens in the corresponding
-# build_defs.bzl, which creates an individual testcase for each example .pbtxt
-# file in this directory.
-#
-load(":build_defs.bzl", "decode_proto_test_suite")
-load(":build_defs.bzl", "encode_proto_test_suite")
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :decode_proto_op_tests.
-decode_proto_test_suite(
- name = "decode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :encode_proto_op_tests.
-encode_proto_test_suite(
- name = "encode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# Below here are tests that are not tied to an example text proto.
-filegroup(
- name = "test_messages",
- srcs = glob(["*.pbtxt"]),
-)
-
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
tf_py_test(
- name = "decode_proto_fail_test",
+ name = "decode_proto_op_test",
size = "small",
- srcs = ["decode_proto_fail_test.py"],
+ srcs = ["decode_proto_op_test.py"],
additional_deps = [
+ ":decode_proto_op_test_base",
":py_test_deps",
- "//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
],
@@ -56,20 +29,63 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "encode_proto_op_test",
+ size = "small",
+ srcs = ["encode_proto_op_test.py"],
+ additional_deps = [
+ ":encode_proto_op_test_base",
+ ":py_test_deps",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
+py_library(
+ name = "proto_op_test_base",
+ testonly = 1,
+ srcs = ["proto_op_test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_library(
- name = "test_case",
- srcs = ["test_case.py"],
- deps = ["//tensorflow/python:client_testlib"],
+ name = "decode_proto_op_test_base",
+ testonly = 1,
+ srcs = ["decode_proto_op_test_base.py"],
+ deps = [
+ ":proto_op_test_base",
+ ":test_example_proto_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
)
py_library(
- name = "py_test_deps",
+ name = "encode_proto_op_test_base",
+ testonly = 1,
+ srcs = ["encode_proto_op_test_base.py"],
deps = [
- ":test_case",
+ ":proto_op_test_base",
":test_example_proto_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
+py_library(name = "py_test_deps")
+
tf_proto_library(
name = "test_example_proto",
srcs = ["test_example.proto"],
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
deleted file mode 100644
index f425601691..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
+++ /dev/null
@@ -1,89 +0,0 @@
-"""BUILD rules for generating file-driven proto test cases.
-
-The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
-of text protos and generates a tf_py_test() for each one.
-"""
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "register_extension_info")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-
-def _test_name(test, path):
- return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
-
-def decode_proto_test_suite(name, examples):
- """Build the decode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("decode_proto", test_filename),
- srcs = ["decode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "decode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("decode_proto", test_filename)
- for test_filename in examples],
- )
-
-def encode_proto_test_suite(name, examples):
- """Build the encode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("encode_proto", test_filename),
- srcs = ["encode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "encode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("encode_proto", test_filename)
- for test_filename in examples],
- )
-
-register_extension_info(
- extension_name = "decode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:decode_example_.*",
- })
-
-register_extension_info(
- extension_name = "encode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:encode_example_.*",
- })
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
deleted file mode 100644
index 5298342ee7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# =============================================================================
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-# Python3 preparedness imports.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.ops import decode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DecodeProtoFailTest(test_case.ProtoOpTestCase):
- """Test failure cases for DecodeToProto."""
-
- def _TestCorruptProtobuf(self, sanitize):
- """Test failure cases for DecodeToProto."""
-
- # The goal here is to check the error reporting.
- # Testing against a variety of corrupt protobufs is
- # done by fuzzing.
- corrupt_proto = 'This is not a binary protobuf'
-
- # Numpy silently truncates the strings if you don't specify dtype=object.
- batch = np.array(corrupt_proto, dtype=object)
- msg_type = 'tensorflow.contrib.proto.TestCase'
- field_names = ['sizes']
- field_types = [dtypes.int32]
-
- with self.test_session() as sess:
- ctensor, vtensor = decode_proto_op.decode_proto(
- batch,
- message_type=msg_type,
- field_names=field_names,
- output_types=field_types,
- sanitize=sanitize)
- with self.assertRaisesRegexp(errors.DataLossError,
- 'Unable to parse binary protobuf'
- '|Failed to consume entire buffer'):
- _ = sess.run([ctensor] + vtensor)
-
- def testCorrupt(self):
- self._TestCorruptProtobuf(sanitize=False)
-
- def testSanitizerCorrupt(self):
- self._TestCorruptProtobuf(sanitize=True)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
index d1c13c82bc..934035ec4c 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -13,287 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Table-driven test for decode_proto op.
+"""Tests for decode_proto op."""
-This test is run once with each of the *.TestCase.pbtxt files
-in the test directory.
-"""
# Python3 preparedness imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from google.protobuf import text_format
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-
-class DecodeProtoOpTest(test_case.ProtoOpTestCase):
-
- def _compareValues(self, fd, vs, evs):
- """Compare lists/arrays of field values."""
-
- if len(vs) != len(evs):
- self.fail('Field %s decoded %d outputs, expected %d' %
- (fd.name, len(vs), len(evs)))
- for i, ev in enumerate(evs):
- # Special case fuzzy match for float32. TensorFlow seems to mess with
- # MAX_FLT slightly and the test doesn't work otherwise.
- # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
- if fd.cpp_type == fd.CPPTYPE_FLOAT:
- # Numpy isclose() is better than assertIsClose() which uses an absolute
- # value comparison.
- self.assertTrue(
- np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
- elif fd.cpp_type == fd.CPPTYPE_STRING:
- # In Python3 string tensor values will be represented as bytes, so we
- # reencode the proto values to match that.
- self.assertEqual(vs[i], ev.encode('ascii'))
- else:
- # Doubles and other types pass through unscathed.
- self.assertEqual(vs[i], ev)
-
- def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields,
- field_dict):
- """Compare protos of type RepeatedPrimitiveValue.
-
- Args:
- batch_shape: the shape of the input tensor of serialized messages.
- sizes: int matrix of repeat counts returned by decode_proto
- fields: list of test_example_pb2.FieldSpec (types and expected values)
- field_dict: map from field names to decoded numpy tensors of values
- """
-
- # Check that expected values match.
- for field in fields:
- values = field_dict[field.name]
- self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
-
- fd = field.expected.DESCRIPTOR.fields_by_name[field.name]
-
- # Values has the same shape as the input plus an extra
- # dimension for repeats.
- self.assertEqual(list(values.shape)[:-1], batch_shape)
-
- # Nested messages are represented as TF strings, requiring
- # some special handling.
- if field.name == 'message_value':
- vs = []
- for buf in values.flat:
- msg = test_example_pb2.PrimitiveValue()
- msg.ParseFromString(buf)
- vs.append(msg)
- evs = getattr(field.expected, field.name)
- if len(vs) != len(evs):
- self.fail('Field %s decoded %d outputs, expected %d' %
- (fd.name, len(vs), len(evs)))
- for v, ev in zip(vs, evs):
- self.assertEqual(v, ev)
- continue
-
- # This can be a little confusing. For testing we are using
- # RepeatedPrimitiveValue in two ways: it's the proto that we
- # decode for testing, and it's used in the expected value as a
- # union type. The two cases are slightly different: this is the
- # second case.
- # We may be fetching the uint64_value from the test proto, but
- # in the expected proto we store it in the int64_value field
- # because TensorFlow doesn't support unsigned int64.
- tf_type_to_primitive_value_field = {
- dtypes.float32:
- 'float_value',
- dtypes.float64:
- 'double_value',
- dtypes.int32:
- 'int32_value',
- dtypes.uint8:
- 'uint8_value',
- dtypes.int8:
- 'int8_value',
- dtypes.string:
- 'string_value',
- dtypes.int64:
- 'int64_value',
- dtypes.bool:
- 'bool_value',
- # Unhandled TensorFlow types:
- # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
- }
- tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
- if tf_field_name is None:
- self.fail('Unhandled tensorflow type %d' % field.dtype)
-
- self._compareValues(fd, values.flat,
- getattr(field.expected, tf_field_name))
-
- def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
- message_type, message_format, sanitize,
- force_disordered=False):
- """Run decode tests on a batch of messages.
-
- Args:
- fields: list of test_example_pb2.FieldSpec (types and expected values)
- case_sizes: expected sizes array
- batch_shape: the shape of the input tensor of serialized messages
- batch: list of serialized messages
- message_type: descriptor name for messages
- message_format: format of messages, 'text' or 'binary'
- sanitize: whether to sanitize binary protobuf inputs
- force_disordered: whether to force fields encoded out of order.
- """
-
- if force_disordered:
- # Exercise code path that handles out-of-order fields by prepending extra
- # fields with tag numbers higher than any real field. Note that this won't
- # work with sanitization because that forces reserialization using a
- # trusted decoder and encoder.
- assert not sanitize
- extra_fields = test_example_pb2.ExtraFields()
- extra_fields.string_value = 'IGNORE ME'
- extra_fields.bool_value = False
- extra_msg = extra_fields.SerializeToString()
- batch = [extra_msg + msg for msg in batch]
-
- # Numpy silently truncates the strings if you don't specify dtype=object.
- batch = np.array(batch, dtype=object)
- batch = np.reshape(batch, batch_shape)
-
- field_names = [f.name for f in fields]
- output_types = [f.dtype for f in fields]
-
- with self.test_session() as sess:
- sizes, vtensor = decode_proto_op.decode_proto(
- batch,
- message_type=message_type,
- field_names=field_names,
- output_types=output_types,
- message_format=message_format,
- sanitize=sanitize)
-
- vlist = sess.run([sizes] + vtensor)
- sizes = vlist[0]
- # Values is a list of tensors, one for each field.
- value_tensors = vlist[1:]
-
- # Check that the repeat sizes are correct.
- self.assertTrue(
- np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
-
- # Check that the decoded sizes match the expected sizes.
- self.assertEqual(len(sizes.flat), len(case_sizes))
- self.assertTrue(
- np.all(sizes.flat == np.array(
- case_sizes, dtype=np.int32)))
-
- field_dict = dict(zip(field_names, value_tensors))
-
- self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
- field_dict)
-
- def testBinary(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=False)
-
- def testBinaryDisordered(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=False,
- force_disordered=True)
-
- def testPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- # Now try with the packed serialization.
- # We test the packed representations by loading the same test cases
- # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
- # To do this we rely on the text format being the same for packed and
- # unpacked fields, and reparse the test message using the packed version
- # of the proto.
- packed_batch = [
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_format.Parse(
- text_format.MessageToString(
- primitive, float_format='.17g'),
- test_example_pb2.PackedPrimitiveValue()).SerializeToString()
- for primitive in case.primitive
- ]
-
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- packed_batch,
- 'tensorflow.contrib.proto.PackedPrimitiveValue',
- 'binary',
- sanitize=False)
-
- def testText(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_batch = [
- text_format.MessageToString(
- primitive, float_format='.17g') for primitive in case.primitive
- ]
-
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- text_batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'text',
- sanitize=False)
- def testSanitizerGood(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
+class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase):
- batch = [primitive.SerializeToString() for primitive in case.primitive]
- self._runDecodeProtoTests(
- case.field,
- case.sizes,
- list(case.shape),
- batch,
- 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
- 'binary',
- sanitize=True)
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
new file mode 100644
index 0000000000..5f7f510352
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -0,0 +1,310 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for decode_proto op."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+
+
+class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
+ """Base class for testing proto decoding ops."""
+
+ def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name
+ """DecodeProtoOpTestBase initializer.
+
+ Args:
+ decode_module: a module containing the `decode_proto_op` method
+ methodName: the name of the test method (same as for test.TestCase)
+ """
+
+ super(DecodeProtoOpTestBase, self).__init__(methodName)
+ self._decode_module = decode_module
+
+ def _compareValues(self, fd, vs, evs):
+ """Compare lists/arrays of field values."""
+
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for i, ev in enumerate(evs):
+ # Special case fuzzy match for float32. TensorFlow seems to mess with
+ # MAX_FLT slightly and the test doesn't work otherwise.
+ # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
+ if fd.cpp_type == fd.CPPTYPE_FLOAT:
+        # Numpy isclose() is better than assertAlmostEqual(), which uses an
+        # absolute difference comparison.
+ self.assertTrue(
+ np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i]))
+ elif fd.cpp_type == fd.CPPTYPE_STRING:
+ # In Python3 string tensor values will be represented as bytes, so we
+ # reencode the proto values to match that.
+ self.assertEqual(vs[i], ev.encode('ascii'))
+ else:
+ # Doubles and other types pass through unscathed.
+ self.assertEqual(vs[i], ev)
+
+ def _compareProtos(self, batch_shape, sizes, fields, field_dict):
+ """Compare protos of type TestValue.
+
+ Args:
+ batch_shape: the shape of the input tensor of serialized messages.
+ sizes: int matrix of repeat counts returned by decode_proto
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ field_dict: map from field names to decoded numpy tensors of values
+ """
+
+ # Check that expected values match.
+ for field in fields:
+ values = field_dict[field.name]
+ self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
+
+ fd = field.value.DESCRIPTOR.fields_by_name[field.name]
+
+ # Values has the same shape as the input plus an extra
+ # dimension for repeats.
+ self.assertEqual(list(values.shape)[:-1], batch_shape)
+
+ # Nested messages are represented as TF strings, requiring
+ # some special handling.
+ if field.name == 'message_value':
+ vs = []
+ for buf in values.flat:
+ msg = test_example_pb2.PrimitiveValue()
+ msg.ParseFromString(buf)
+ vs.append(msg)
+ evs = getattr(field.value, field.name)
+ if len(vs) != len(evs):
+ self.fail('Field %s decoded %d outputs, expected %d' %
+ (fd.name, len(vs), len(evs)))
+ for v, ev in zip(vs, evs):
+ self.assertEqual(v, ev)
+ continue
+
+ # This can be a little confusing. For testing we are using TestValue in
+ # two ways: it's the proto that we decode for testing, and it's used in
+ # the expected value as a union type.
+ #
+ # The two cases are slightly different: this is the second case. We may be
+ # fetching the uint64_value from the test proto, but in the expected proto
+ # we store it in the int64_value field because TensorFlow doesn't support
+ # unsigned int64.
+ tf_type_to_primitive_value_field = {
+ dtypes.float32:
+ 'float_value',
+ dtypes.float64:
+ 'double_value',
+ dtypes.int32:
+ 'int32_value',
+ dtypes.uint8:
+ 'uint8_value',
+ dtypes.int8:
+ 'int8_value',
+ dtypes.string:
+ 'string_value',
+ dtypes.int64:
+ 'int64_value',
+ dtypes.bool:
+ 'bool_value',
+ # Unhandled TensorFlow types:
+ # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
+ # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ }
+ tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
+ if tf_field_name is None:
+ self.fail('Unhandled tensorflow type %d' % field.dtype)
+
+ self._compareValues(fd, values.flat,
+ getattr(field.value, tf_field_name))
+
+ def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
+ message_type, message_format, sanitize,
+ force_disordered=False):
+ """Run decode tests on a batch of messages.
+
+ Args:
+ fields: list of test_example_pb2.FieldSpec (types and expected values)
+ case_sizes: expected sizes array
+ batch_shape: the shape of the input tensor of serialized messages
+ batch: list of serialized messages
+ message_type: descriptor name for messages
+ message_format: format of messages, 'text' or 'binary'
+ sanitize: whether to sanitize binary protobuf inputs
+ force_disordered: whether to force fields encoded out of order.
+ """
+
+ if force_disordered:
+ # Exercise code path that handles out-of-order fields by prepending extra
+ # fields with tag numbers higher than any real field. Note that this won't
+ # work with sanitization because that forces reserialization using a
+ # trusted decoder and encoder.
+ assert not sanitize
+ extra_fields = test_example_pb2.ExtraFields()
+ extra_fields.string_value = 'IGNORE ME'
+ extra_fields.bool_value = False
+ extra_msg = extra_fields.SerializeToString()
+ batch = [extra_msg + msg for msg in batch]
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(batch, dtype=object)
+ batch = np.reshape(batch, batch_shape)
+
+ field_names = [f.name for f in fields]
+ output_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, vtensor = self._decode_module.decode_proto(
+ batch,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=output_types,
+ message_format=message_format,
+ sanitize=sanitize)
+
+ vlist = sess.run([sizes] + vtensor)
+ sizes = vlist[0]
+ # Values is a list of tensors, one for each field.
+ value_tensors = vlist[1:]
+
+ # Check that the repeat sizes are correct.
+ self.assertTrue(
+ np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))
+
+ # Check that the decoded sizes match the expected sizes.
+ self.assertEqual(len(sizes.flat), len(case_sizes))
+ self.assertTrue(
+ np.all(sizes.flat == np.array(
+ case_sizes, dtype=np.int32)))
+
+ field_dict = dict(zip(field_names, value_tensors))
+
+ self._compareProtos(batch_shape, sizes, fields, field_dict)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinary(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=False)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinaryDisordered(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=False,
+ force_disordered=True)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testPacked(self, case):
+ # Now try with the packed serialization.
+ #
+ # We test the packed representations by loading the same test case using
+ # PackedTestValue instead of TestValue. To do this we rely on the text
+ # format being the same for packed and unpacked fields, and reparse the
+ # test message using the packed version of the proto.
+ packed_batch = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ value, float_format='.17g'),
+ test_example_pb2.PackedTestValue()).SerializeToString()
+ for value in case.values
+ ]
+
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ packed_batch,
+ 'tensorflow.contrib.proto.PackedTestValue',
+ 'binary',
+ sanitize=False)
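+
+ # The reparse trick works because text format carries no packing
+ # information. For a single message the round trip is:
+ #
+ #   text = text_format.MessageToString(value, float_format='.17g')
+ #   packed = text_format.Parse(text, test_example_pb2.PackedTestValue())
+ #   wire = packed.SerializeToString()  # packed encoding of the same data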
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testText(self, case):
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_batch = [
+ text_format.MessageToString(
+ value, float_format='.17g') for value in case.values
+ ]
+
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ text_batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'text',
+ sanitize=False)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testSanitizerGood(self, case):
+ batch = [value.SerializeToString() for value in case.values]
+ self._runDecodeProtoTests(
+ case.fields,
+ case.sizes,
+ list(case.shapes),
+ batch,
+ 'tensorflow.contrib.proto.TestValue',
+ 'binary',
+ sanitize=True)
+
+ @parameterized.parameters((False), (True))
+ def testCorruptProtobuf(self, sanitize):
+ corrupt_proto = 'This is not a binary protobuf'
+
+ # Numpy silently truncates the strings if you don't specify dtype=object.
+ batch = np.array(corrupt_proto, dtype=object)
+ msg_type = 'tensorflow.contrib.proto.TestCase'
+ field_names = ['sizes']
+ field_types = [dtypes.int32]
+
+ with self.test_session() as sess:
+ ctensor, vtensor = self._decode_module.decode_proto(
+ batch,
+ message_type=msg_type,
+ field_names=field_names,
+ output_types=field_types,
+ sanitize=sanitize)
+ with self.assertRaisesRegexp(errors.DataLossError,
+ 'Unable to parse binary protobuf'
+ '|Failed to consume entire buffer'):
+ _ = sess.run([ctensor] + vtensor)
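+
+ # The regexp above accepts either message because the failure surfaces
+ # differently depending on the path taken: the sanitizing reserialization
+ # and the kernel's own parser report corrupt input differently.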
diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
deleted file mode 100644
index 4e31681907..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
+++ /dev/null
@@ -1,94 +0,0 @@
-primitive {
- # No fields specified, so we get all defaults
-}
-shape: 1
-sizes: 0
-field {
- name: "double_default"
- dtype: DT_DOUBLE
- expected { double_value: 1.0 }
-}
-sizes: 0
-field {
- name: "float_default"
- dtype: DT_DOUBLE # Try casting the float field to double.
- expected { double_value: 2.0 }
-}
-sizes: 0
-field {
- name: "int64_default"
- dtype: DT_INT64
- expected { int64_value: 3 }
-}
-sizes: 0
-field {
- name: "uint64_default"
- dtype: DT_INT64
- expected { int64_value: 4 }
-}
-sizes: 0
-field {
- name: "int32_default"
- dtype: DT_INT32
- expected { int32_value: 5 }
-}
-sizes: 0
-field {
- name: "fixed64_default"
- dtype: DT_INT64
- expected { int64_value: 6 }
-}
-sizes: 0
-field {
- name: "fixed32_default"
- dtype: DT_INT32
- expected { int32_value: 7 }
-}
-sizes: 0
-field {
- name: "bool_default"
- dtype: DT_BOOL
- expected { bool_value: true }
-}
-sizes: 0
-field {
- name: "string_default"
- dtype: DT_STRING
- expected { string_value: "a" }
-}
-sizes: 0
-field {
- name: "bytes_default"
- dtype: DT_STRING
- expected { string_value: "a longer default string" }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT32
- expected { int32_value: -1 }
-}
-sizes: 0
-field {
- name: "sfixed32_default"
- dtype: DT_INT32
- expected { int32_value: 10 }
-}
-sizes: 0
-field {
- name: "sfixed64_default"
- dtype: DT_INT64
- expected { int64_value: 11 }
-}
-sizes: 0
-field {
- name: "sint32_default"
- dtype: DT_INT32
- expected { int32_value: 12 }
-}
-sizes: 0
-field {
- name: "sint64_default"
- dtype: DT_INT64
- expected { int64_value: 13 }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
index 30e58e6336..fc5cd25d43 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -13,167 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Table-driven test for encode_proto op.
+"""Tests for encode_proto op."""
-This test is run once with each of the *.TestCase.pbtxt files
-in the test directory.
-
-It tests that encode_proto is a lossless inverse of decode_proto
-(for the specified fields).
-"""
# Python3 readiness boilerplate
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from google.protobuf import text_format
-
-from tensorflow.contrib.proto.python.kernel_tests import test_case
-from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-
-class EncodeProtoOpTest(test_case.ProtoOpTestCase):
-
- def testBadInputs(self):
- # Invalid field name
- with self.test_session():
- with self.assertRaisesOpError('Unknown field: non_existent_field'):
- encode_proto_op.encode_proto(
- sizes=[[1]],
- values=[np.array([[0.0]], dtype=np.int32)],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['non_existent_field']).eval()
-
- # Incorrect types.
- with self.test_session():
- with self.assertRaisesOpError(
- 'Incompatible type for field double_value.'):
- encode_proto_op.encode_proto(
- sizes=[[1]],
- values=[np.array([[0.0]], dtype=np.int32)],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value']).eval()
-
- # Incorrect shapes of sizes.
- with self.test_session():
- with self.assertRaisesOpError(
- r'sizes should be batch_size \+ \[len\(field_names\)\]'):
- sizes = array_ops.placeholder(dtypes.int32)
- values = array_ops.placeholder(dtypes.float64)
- encode_proto_op.encode_proto(
- sizes=sizes,
- values=[values],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value']).eval(feed_dict={
- sizes: [[[0, 0]]],
- values: [[0.0]]
- })
-
- # Inconsistent shapes of values.
- with self.test_session():
- with self.assertRaisesOpError(
- 'Values must match up to the last dimension'):
- sizes = array_ops.placeholder(dtypes.int32)
- values1 = array_ops.placeholder(dtypes.float64)
- values2 = array_ops.placeholder(dtypes.int32)
- (encode_proto_op.encode_proto(
- sizes=[[1, 1]],
- values=[values1, values2],
- message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
- field_names=['double_value', 'int32_value']).eval(feed_dict={
- values1: [[0.0]],
- values2: [[0], [0]]
- }))
-
- def _testRoundtrip(self, in_bufs, message_type, fields):
-
- field_names = [f.name for f in fields]
- out_types = [f.dtype for f in fields]
-
- with self.test_session() as sess:
- sizes, field_tensors = decode_proto_op.decode_proto(
- in_bufs,
- message_type=message_type,
- field_names=field_names,
- output_types=out_types)
-
- out_tensors = encode_proto_op.encode_proto(
- sizes,
- field_tensors,
- message_type=message_type,
- field_names=field_names)
-
- out_bufs, = sess.run([out_tensors])
-
- # Check that the re-encoded tensor has the same shape.
- self.assertEqual(in_bufs.shape, out_bufs.shape)
-
- # Compare the input and output.
- for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
- in_obj = test_example_pb2.RepeatedPrimitiveValue()
- in_obj.ParseFromString(in_buf)
-
- out_obj = test_example_pb2.RepeatedPrimitiveValue()
- out_obj.ParseFromString(out_buf)
-
- # Check that the deserialized objects are identical.
- self.assertEqual(in_obj, out_obj)
-
- # Check that the input and output serialized messages are identical.
- # If we fail here, there is a difference in the serialized
- # representation but the new serialization still parses. This could
- # be harmless (a change in map ordering?) or it could be bad (e.g.
- # loss of packing in the encoding).
- self.assertEqual(in_buf, out_buf)
-
- def testRoundtrip(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
- in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
-
- # np.array silently truncates strings if you don't specify dtype=object.
- in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
- return self._testRoundtrip(
- in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
-
- def testRoundtripPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
- # Now try with the packed serialization.
- # We test the packed representations by loading the same test cases
- # using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
- # To do this we rely on the text format being the same for packed and
- # unpacked fields, and reparse the test message using the packed version
- # of the proto.
- in_bufs = [
- # Note: float_format='.17g' is necessary to ensure preservation of
- # doubles and floats in text format.
- text_format.Parse(
- text_format.MessageToString(
- primitive, float_format='.17g'),
- test_example_pb2.PackedPrimitiveValue()).SerializeToString()
- for primitive in case.primitive
- ]
+class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase):
- # np.array silently truncates strings if you don't specify dtype=object.
- in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape))
- return self._testRoundtrip(
- in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field)
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(EncodeProtoOpTest, self).__init__(decode_proto_op, encode_proto_op,
+ methodName)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
new file mode 100644
index 0000000000..07dfb924d3
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
@@ -0,0 +1,177 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Table-driven test for encode_proto op.
+
+This test is run once with each of the *.TestCase.pbtxt files
+in the test directory.
+
+It tests that encode_proto is a lossless inverse of decode_proto
+(for the specified fields).
+"""
+# Python3 readiness boilerplate
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+
+
+class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
+ """Base class for testing proto encoding ops."""
+
+ def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name
+ """EncodeProtoOpTestBase initializer.
+
+ Args:
+ decode_module: a module providing the `decode_proto` op
+ encode_module: a module providing the `encode_proto` op
+ methodName: the name of the test method (same as for test.TestCase)
+ """
+
+ super(EncodeProtoOpTestBase, self).__init__(methodName)
+ self._decode_module = decode_module
+ self._encode_module = encode_module
+
+ def testBadInputs(self):
+ # Invalid field name
+ with self.test_session():
+ with self.assertRaisesOpError('Unknown field: non_existent_field'):
+ self._encode_module.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['non_existent_field']).eval()
+
+ # Incorrect types.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Incompatible type for field double_value.'):
+ self._encode_module.encode_proto(
+ sizes=[[1]],
+ values=[np.array([[0.0]], dtype=np.int32)],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value']).eval()
+
+ # Incorrect shapes of sizes.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ r'sizes should be batch_size \+ \[len\(field_names\)\]'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values = array_ops.placeholder(dtypes.float64)
+ self._encode_module.encode_proto(
+ sizes=sizes,
+ values=[values],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value']).eval(feed_dict={
+ sizes: [[[0, 0]]],
+ values: [[0.0]]
+ })
+
+ # Inconsistent shapes of values.
+ with self.test_session():
+ with self.assertRaisesOpError(
+ 'Values must match up to the last dimension'):
+ sizes = array_ops.placeholder(dtypes.int32)
+ values1 = array_ops.placeholder(dtypes.float64)
+ values2 = array_ops.placeholder(dtypes.int32)
+ (self._encode_module.encode_proto(
+ sizes=[[1, 1]],
+ values=[values1, values2],
+ message_type='tensorflow.contrib.proto.TestValue',
+ field_names=['double_value', 'int32_value']).eval(feed_dict={
+ values1: [[0.0]],
+ values2: [[0], [0]]
+ }))
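+
+ # For contrast, a well-formed call (hypothetical values): with batch shape
+ # [1], one field, and one value per message, sizes has shape [1, 1] and
+ # the values tensor shape [1, 1]:
+ #
+ #   self._encode_module.encode_proto(
+ #       sizes=[[1]],
+ #       values=[np.array([[23.5]], dtype=np.float64)],
+ #       message_type='tensorflow.contrib.proto.TestValue',
+ #       field_names=['double_value'])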
+
+ def _testRoundtrip(self, in_bufs, message_type, fields):
+
+ field_names = [f.name for f in fields]
+ out_types = [f.dtype for f in fields]
+
+ with self.test_session() as sess:
+ sizes, field_tensors = self._decode_module.decode_proto(
+ in_bufs,
+ message_type=message_type,
+ field_names=field_names,
+ output_types=out_types)
+
+ out_tensors = self._encode_module.encode_proto(
+ sizes,
+ field_tensors,
+ message_type=message_type,
+ field_names=field_names)
+
+ out_bufs, = sess.run([out_tensors])
+
+ # Check that the re-encoded tensor has the same shape.
+ self.assertEqual(in_bufs.shape, out_bufs.shape)
+
+ # Compare the input and output.
+ for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
+ in_obj = test_example_pb2.TestValue()
+ in_obj.ParseFromString(in_buf)
+
+ out_obj = test_example_pb2.TestValue()
+ out_obj.ParseFromString(out_buf)
+
+ # Check that the deserialized objects are identical.
+ self.assertEqual(in_obj, out_obj)
+
+ # Check that the input and output serialized messages are identical.
+ # If we fail here, there is a difference in the serialized
+ # representation but the new serialization still parses. This could
+ # be harmless (a change in map ordering?) or it could be bad (e.g.
+ # loss of packing in the encoding).
+ self.assertEqual(in_buf, out_buf)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtrip(self, case):
+ in_bufs = [value.SerializeToString() for value in case.values]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields)
+
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtripPacked(self, case):
+ # Now try with the packed serialization.
+ # We test the packed representations by loading the same test cases using
+ # PackedTestValue instead of TestValue. To do this we rely on the text
+ # format being the same for packed and unpacked fields, and reparse the test
+ # message using the packed version of the proto.
+ in_bufs = [
+ # Note: float_format='.17g' is necessary to ensure preservation of
+ # doubles and floats in text format.
+ text_format.Parse(
+ text_format.MessageToString(
+ value, float_format='.17g'),
+ test_example_pb2.PackedTestValue()).SerializeToString()
+ for value in case.values
+ ]
+
+ # np.array silently truncates strings if you don't specify dtype=object.
+ in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes))
+ return self._testRoundtrip(
+ in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
deleted file mode 100644
index b170f89c0f..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
+++ /dev/null
@@ -1,161 +0,0 @@
-primitive {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- uint64_value: 0
- uint64_value: 18446744073709551615
- int32_value: -2147483648
- int32_value: 2147483647
- fixed64_value: 0
- fixed64_value: 18446744073709551615
- fixed32_value: 0
- fixed32_value: 4294967295
- bool_value: false
- bool_value: true
- string_value: ""
- string_value: "I refer to the infinite."
- uint32_value: 0
- uint32_value: 4294967295
- sfixed32_value: -2147483648
- sfixed32_value: 2147483647
- sfixed64_value: -9223372036854775808
- sfixed64_value: 9223372036854775807
- sint32_value: -2147483648
- sint32_value: 2147483647
- sint64_value: -9223372036854775808
- sint64_value: 9223372036854775807
-}
-shape: 1
-sizes: 3
-sizes: 3
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- }
-}
-field {
- name: "float_value"
- dtype: DT_FLOAT
- expected {
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- }
-}
-field {
- name: "int64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "uint64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1
- }
-}
-field {
- name: "int32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "fixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1 # unsigned is 18446744073709551615
- }
-}
-field {
- name: "fixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: false
- bool_value: true
- }
-}
-field {
- name: "string_value"
- dtype: DT_STRING
- expected {
- string_value: ""
- string_value: "I refer to the infinite."
- }
-}
-field {
- name: "uint32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "sfixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sfixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "sint32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sint64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
deleted file mode 100644
index c664e52851..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-primitive {
- message_value {
- double_value: 23.5
- }
-}
-shape: 1
-sizes: 1
-field {
- name: "message_value"
- dtype: DT_STRING
- expected {
- message_value {
- double_value: 23.5
- }
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
deleted file mode 100644
index 125651d7ea..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
+++ /dev/null
@@ -1,20 +0,0 @@
-primitive {
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 0
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 0.0
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
deleted file mode 100644
index bc07efc8f3..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
+++ /dev/null
@@ -1,29 +0,0 @@
-primitive {
- fixed32_value: 4294967295
- uint32_value: 4294967295
-}
-shape: 1
-sizes: 1
-field {
- name: "fixed32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 1
-field {
- name: "uint32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295 # Comes from an explicitly-specified default
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
new file mode 100644
index 0000000000..cbc7b3d3f8
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
@@ -0,0 +1,407 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 readiness boilerplate
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestBase(test.TestCase):
+ """Base class for testing proto decoding and encoding ops."""
+
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(ProtoOpTestBase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
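+ # Loading the shared object registers the compiled C++ descriptors for
+ # the test protos with the process-wide descriptor pool, so the kernels
+ # can resolve message_type names like
+ # 'tensorflow.contrib.proto.TestValue'.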
+
+ @staticmethod
+ def named_parameters():
+ return (
+ ("defaults", ProtoOpTestBase.defaults_test_case()),
+ ("minmax", ProtoOpTestBase.minmax_test_case()),
+ ("nested", ProtoOpTestBase.nested_test_case()),
+ ("optional", ProtoOpTestBase.optional_test_case()),
+ ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
+ ("ragged", ProtoOpTestBase.ragged_test_case()),
+ ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
+ ("simple", ProtoOpTestBase.simple_test_case()),
+ )
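+
+ # Each (name, TestCase) tuple above feeds
+ # @parameterized.named_parameters(*ProtoOpTestBase.named_parameters()),
+ # which expands a decorated test method into one variant per named case.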
+
+ @staticmethod
+ def defaults_test_case():
+ test_case = test_example_pb2.TestCase()
+ test_case.values.add() # No fields specified, so we get all defaults.
+ test_case.shapes.append(1)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "double_value_with_default"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(1.0)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "float_value_with_default"
+ field.dtype = types_pb2.DT_FLOAT
+ field.value.float_value.append(2.0)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "int64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(3)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sfixed64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(11)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sint64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(13)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "uint64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(4)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "fixed64_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(6)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "int32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(5)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sfixed32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(10)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "sint32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(12)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "uint32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(9)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "fixed32_value_with_default"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(7)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "bool_value_with_default"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "string_value_with_default"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("a")
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "bytes_value_with_default"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("a longer default string")
+ return test_case
+
+ @staticmethod
+ def minmax_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(-1.7976931348623158e+308)
+ value.double_value.append(2.2250738585072014e-308)
+ value.double_value.append(1.7976931348623158e+308)
+ value.float_value.append(-3.402823466e+38)
+ value.float_value.append(1.175494351e-38)
+ value.float_value.append(3.402823466e+38)
+ value.int64_value.append(-9223372036854775808)
+ value.int64_value.append(9223372036854775807)
+ value.sfixed64_value.append(-9223372036854775808)
+ value.sfixed64_value.append(9223372036854775807)
+ value.sint64_value.append(-9223372036854775808)
+ value.sint64_value.append(9223372036854775807)
+ value.uint64_value.append(0)
+ value.uint64_value.append(18446744073709551615)
+ value.fixed64_value.append(0)
+ value.fixed64_value.append(18446744073709551615)
+ value.int32_value.append(-2147483648)
+ value.int32_value.append(2147483647)
+ value.sfixed32_value.append(-2147483648)
+ value.sfixed32_value.append(2147483647)
+ value.sint32_value.append(-2147483648)
+ value.sint32_value.append(2147483647)
+ value.uint32_value.append(0)
+ value.uint32_value.append(4294967295)
+ value.fixed32_value.append(0)
+ value.fixed32_value.append(4294967295)
+ value.bool_value.append(False)
+ value.bool_value.append(True)
+ value.string_value.append("")
+ value.string_value.append("I refer to the infinite.")
+ test_case.shapes.append(1)
+ test_case.sizes.append(3)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(-1.7976931348623158e+308)
+ field.value.double_value.append(2.2250738585072014e-308)
+ field.value.double_value.append(1.7976931348623158e+308)
+ test_case.sizes.append(3)
+ field = test_case.fields.add()
+ field.name = "float_value"
+ field.dtype = types_pb2.DT_FLOAT
+ field.value.float_value.append(-3.402823466e+38)
+ field.value.float_value.append(1.175494351e-38)
+ field.value.float_value.append(3.402823466e+38)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "int64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sfixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(-9223372036854775808)
+ field.value.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "uint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(0)
+ field.value.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "fixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(0)
+ field.value.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "int32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sfixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "sint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(-2147483648)
+ field.value.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(0)
+ field.value.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.value.int32_value.append(0)
+ field.value.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ test_case.sizes.append(2)
+ field = test_case.fields.add()
+ field.name = "string_value"
+ field.dtype = types_pb2.DT_STRING
+ field.value.string_value.append("")
+ field.value.string_value.append("I refer to the infinite.")
+ return test_case
+
+ @staticmethod
+ def nested_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ message_value = value.message_value.add()
+ message_value.double_value = 23.5
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "message_value"
+ field.dtype = types_pb2.DT_STRING
+ message_value = field.value.message_value.add()
+ message_value.double_value = 23.5
+ return test_case
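+
+ # Note the submessage field is declared DT_STRING: decode_proto returns
+ # nested messages as serialized bytes. A sketch of recovering the value
+ # from one decoded element (decoded_bytes is hypothetical):
+ #
+ #   inner = test_example_pb2.PrimitiveValue()
+ #   inner.ParseFromString(decoded_bytes)
+ #   assert inner.double_value == 23.5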
+
+ @staticmethod
+ def optional_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.bool_value.append(True)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(0.0)
+ return test_case
+
+ @staticmethod
+ def promote_unsigned_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.fixed32_value.append(4294967295)
+ value.uint32_value.append(4294967295)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(4294967295)
+ # The expected value below comes from an explicitly-specified default.
+ test_case.sizes.append(0)
+ field = test_case.fields.add()
+ field.name = "uint32_value_with_default"
+ field.dtype = types_pb2.DT_INT64
+ field.value.int64_value.append(9)
+ return test_case
+
+ @staticmethod
+ def ragged_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.double_value.append(123.0)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(3.1)
+ value.bool_value.append(False)
+ test_case.shapes.append(2)
+ test_case.sizes.append(2)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ field.value.double_value.append(123.0)
+ field.value.double_value.append(3.1)
+ field.value.double_value.append(0.0)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ return test_case
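+
+ # With shapes [2] and double_value sizes [2, 1], the decoded dense tensor
+ # is padded with the default 0.0 up to the largest repeat count:
+ # [[23.5, 123.0], [3.1, 0.0]], matching the expected values above.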
+
+ @staticmethod
+ def shaped_batch_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(44.0)
+ value.bool_value.append(False)
+ value = test_case.values.add()
+ value.double_value.append(3.14159)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(1.414)
+ value.bool_value.append(True)
+ value = test_case.values.add()
+ value.double_value.append(-32.2)
+ value.bool_value.append(False)
+ value = test_case.values.add()
+ value.double_value.append(0.0001)
+ value.bool_value.append(True)
+ test_case.shapes.append(3)
+ test_case.shapes.append(2)
+ for _ in range(12):
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ field.value.double_value.append(44.0)
+ field.value.double_value.append(3.14159)
+ field.value.double_value.append(1.414)
+ field.value.double_value.append(-32.2)
+ field.value.double_value.append(0.0001)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(True)
+ field.value.bool_value.append(False)
+ field.value.bool_value.append(True)
+ return test_case
+
+ @staticmethod
+ def simple_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ value.double_value.append(23.5)
+ value.bool_value.append(True)
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.value.double_value.append(23.5)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.value.bool_value.append(True)
+ return test_case
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
deleted file mode 100644
index 61c7ac53f7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
+++ /dev/null
@@ -1,32 +0,0 @@
-primitive {
- double_value: 23.5
- double_value: 123.0
- bool_value: true
-}
-primitive {
- double_value: 3.1
- bool_value: false
-}
-shape: 2
-sizes: 2
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 123.0
- double_value: 3.1
- double_value: 0.0
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
deleted file mode 100644
index f4828076d5..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
+++ /dev/null
@@ -1,62 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-primitive {
- double_value: 44.0
- bool_value: false
-}
-primitive {
- double_value: 3.14159
- bool_value: true
-}
-primitive {
- double_value: 1.414
- bool_value: true
-}
-primitive {
- double_value: -32.2
- bool_value: false
-}
-primitive {
- double_value: 0.0001
- bool_value: true
-}
-shape: 3
-shape: 2
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 44.0
- double_value: 3.14159
- double_value: 1.414
- double_value: -32.2
- double_value: 0.0001
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- bool_value: true
- bool_value: true
- bool_value: false
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
deleted file mode 100644
index dc20ac147b..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
index a2c88e372b..674d881220 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
@@ -1,6 +1,4 @@
// Test description and protos to work with it.
-//
-// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2";
@@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto";
package tensorflow.contrib.proto;
-// A TestCase holds a proto and a bunch of assertions
-// about how it should decode.
+// A TestCase holds a proto and assertions about how it should decode.
message TestCase {
- // A batch of primitives to be serialized and decoded.
- repeated RepeatedPrimitiveValue primitive = 1;
- // The shape of the batch.
- repeated int32 shape = 2;
+ // Batches of primitive values.
+ repeated TestValue values = 1;
+ // The batch shapes.
+ repeated int32 shapes = 2;
// Expected sizes for each field.
repeated int32 sizes = 3;
// Expected values for each field.
- repeated FieldSpec field = 4;
+ repeated FieldSpec fields = 4;
};
// FieldSpec describes the expected output for a single field.
message FieldSpec {
optional string name = 1;
optional tensorflow.DataType dtype = 2;
- optional RepeatedPrimitiveValue expected = 3;
+ optional TestValue value = 3;
};
+// NOTE: This definition must be kept in sync with PackedTestValue.
message TestValue {
- optional PrimitiveValue primitive_value = 1;
- optional EnumValue enum_value = 2;
- optional MessageValue message_value = 3;
- optional RepeatedMessageValue repeated_message_value = 4;
- optional RepeatedPrimitiveValue repeated_primitive_value = 6;
-}
-
-message PrimitiveValue {
- optional double double_value = 1;
- optional float float_value = 2;
- optional int64 int64_value = 3;
- optional uint64 uint64_value = 4;
- optional int32 int32_value = 5;
- optional fixed64 fixed64_value = 6;
- optional fixed32 fixed32_value = 7;
- optional bool bool_value = 8;
- optional string string_value = 9;
- optional bytes bytes_value = 12;
- optional uint32 uint32_value = 13;
- optional sfixed32 sfixed32_value = 15;
- optional sfixed64 sfixed64_value = 16;
- optional sint32 sint32_value = 17;
- optional sint64 sint64_value = 18;
-}
-
-// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
-message RepeatedPrimitiveValue {
repeated double double_value = 1;
repeated float float_value = 2;
repeated int64 int64_value = 3;
@@ -74,30 +45,31 @@ message RepeatedPrimitiveValue {
repeated PrimitiveValue message_value = 19;
// Optional fields with explicitly-specified defaults.
- optional double double_default = 20 [default = 1.0];
- optional float float_default = 21 [default = 2.0];
- optional int64 int64_default = 22 [default = 3];
- optional uint64 uint64_default = 23 [default = 4];
- optional int32 int32_default = 24 [default = 5];
- optional fixed64 fixed64_default = 25 [default = 6];
- optional fixed32 fixed32_default = 26 [default = 7];
- optional bool bool_default = 27 [default = true];
- optional string string_default = 28 [default = "a"];
- optional bytes bytes_default = 29 [default = "a longer default string"];
- optional uint32 uint32_default = 30 [default = 4294967295];
- optional sfixed32 sfixed32_default = 31 [default = 10];
- optional sfixed64 sfixed64_default = 32 [default = 11];
- optional sint32 sint32_default = 33 [default = 12];
- optional sint64 sint64_default = 34 [default = 13];
+ optional double double_value_with_default = 20 [default = 1.0];
+ optional float float_value_with_default = 21 [default = 2.0];
+ optional int64 int64_value_with_default = 22 [default = 3];
+ optional uint64 uint64_value_with_default = 23 [default = 4];
+ optional int32 int32_value_with_default = 24 [default = 5];
+ optional fixed64 fixed64_value_with_default = 25 [default = 6];
+ optional fixed32 fixed32_value_with_default = 26 [default = 7];
+ optional bool bool_value_with_default = 27 [default = true];
+ optional string string_value_with_default = 28 [default = "a"];
+ optional bytes bytes_value_with_default = 29
+ [default = "a longer default string"];
+ optional uint32 uint32_value_with_default = 30 [default = 9];
+ optional sfixed32 sfixed32_value_with_default = 31 [default = 10];
+ optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
+ optional sint32 sint32_value_with_default = 33 [default = 12];
+ optional sint64 sint64_value_with_default = 34 [default = 13];
}
-// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
-// in the text format, but the binary serializion is different.
-// We test the packed representations by loading the same test cases
-// using this definition instead of RepeatedPrimitiveValue.
-// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
-// in every way except the packed=true declaration.
-message PackedPrimitiveValue {
+// A PackedTestValue looks exactly the same as a TestValue in the text format,
+// but the binary serialization is different. We test the packed representations
+// by loading the same test cases using this definition instead of TestValue.
+//
+// NOTE: This definition must be kept in sync with TestValue in every way except
+// the packed=true declaration.
+message PackedTestValue {
repeated double double_value = 1 [packed = true];
repeated float float_value = 2 [packed = true];
repeated int64 int64_value = 3 [packed = true];
@@ -115,23 +87,53 @@ message PackedPrimitiveValue {
repeated sint64 sint64_value = 18 [packed = true];
repeated PrimitiveValue message_value = 19;
- optional double double_default = 20 [default = 1.0];
- optional float float_default = 21 [default = 2.0];
- optional int64 int64_default = 22 [default = 3];
- optional uint64 uint64_default = 23 [default = 4];
- optional int32 int32_default = 24 [default = 5];
- optional fixed64 fixed64_default = 25 [default = 6];
- optional fixed32 fixed32_default = 26 [default = 7];
- optional bool bool_default = 27 [default = true];
- optional string string_default = 28 [default = "a"];
- optional bytes bytes_default = 29 [default = "a longer default string"];
- optional uint32 uint32_default = 30 [default = 4294967295];
- optional sfixed32 sfixed32_default = 31 [default = 10];
- optional sfixed64 sfixed64_default = 32 [default = 11];
- optional sint32 sint32_default = 33 [default = 12];
- optional sint64 sint64_default = 34 [default = 13];
+ optional double double_value_with_default = 20 [default = 1.0];
+ optional float float_value_with_default = 21 [default = 2.0];
+ optional int64 int64_value_with_default = 22 [default = 3];
+ optional uint64 uint64_value_with_default = 23 [default = 4];
+ optional int32 int32_value_with_default = 24 [default = 5];
+ optional fixed64 fixed64_value_with_default = 25 [default = 6];
+ optional fixed32 fixed32_value_with_default = 26 [default = 7];
+ optional bool bool_value_with_default = 27 [default = true];
+ optional string string_value_with_default = 28 [default = "a"];
+ optional bytes bytes_value_with_default = 29
+ [default = "a longer default string"];
+ optional uint32 uint32_value_with_default = 30 [default = 9];
+ optional sfixed32 sfixed32_value_with_default = 31 [default = 10];
+ optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
+ optional sint32 sint32_value_with_default = 33 [default = 12];
+ optional sint64 sint64_value_with_default = 34 [default = 13];
}
+message PrimitiveValue {
+ optional double double_value = 1;
+ optional float float_value = 2;
+ optional int64 int64_value = 3;
+ optional uint64 uint64_value = 4;
+ optional int32 int32_value = 5;
+ optional fixed64 fixed64_value = 6;
+ optional fixed32 fixed32_value = 7;
+ optional bool bool_value = 8;
+ optional string string_value = 9;
+ optional bytes bytes_value = 12;
+ optional uint32 uint32_value = 13;
+ optional sfixed32 sfixed32_value = 15;
+ optional sfixed64 sfixed64_value = 16;
+ optional sint32 sint32_value = 17;
+ optional sint64 sint64_value = 18;
+}
+
+// Message containing fields with field numbers higher than any field above.
+// An instance of this message is prepended to each binary message in the test
+// to exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+ optional string string_value = 1776;
+ optional bool bool_value = 1777;
+}
+
+// The messages below are for tests that have not been written yet.
+
message EnumValue {
enum Color {
RED = 0;
@@ -171,12 +173,3 @@ message RepeatedMessageValue {
repeated NestedMessageValue message_values = 11;
}
-
-// Message containing fields with field numbers higher than any field above. An
-// instance of this message is prepended to each binary message in the test to
-// exercise the code path that handles fields encoded out of order of field
-// number.
-message ExtraFields {
- optional string string_value = 1776;
- optional bool bool_value = 1777;
-}
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 19e5bef1ea..4fc315d901 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -278,6 +278,13 @@ def _FindLayersToQuantize(graph):
],
ordered_inputs=False)
+ # Batch norms with forced updates have an Identity operation at the end.
+ # TODO(suharshs): Find a way to easily skip extra Identity operations. The
+ # current issue is that doing so can often match patterns across many layers
+ # incorrectly.
+ batch_norm_identity = graph_matcher.OpTypePattern(
+ 'Identity', inputs=[folded_bias_add_pattern])
+
bias_add_pattern = graph_matcher.OpTypePattern(
'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)
@@ -286,20 +293,22 @@ def _FindLayersToQuantize(graph):
'Add',
inputs=[
graph_matcher.OneofPattern(
- [bias_add_pattern, folded_bias_add_pattern]), '*'
+ [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
+ '*'
],
ordered_inputs=False)
# The input to the activation can come from bias add, fold bias add, the
# bypasses.
# TODO(suharshs): We should ideally skip Identity operations instead of
- # treating them as an activation.
+ # treating them as activations.
activation_pattern = graph_matcher.OpTypePattern(
'|'.join(_ACTIVATION_TYPES) + '|Identity',
inputs=[
graph_matcher.OneofPattern([
bias_add_pattern,
folded_bias_add_pattern,
+ batch_norm_identity,
bypass_pattern,
])
])
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 11d052d7f4..2944f964c7 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -191,6 +191,7 @@ def experimental_create_training_graph(input_graph=None,
def experimental_create_eval_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
+ quant_delay=None,
scope=None):
"""Rewrites an eval input_graph in place for simulated quantization.
@@ -209,6 +210,8 @@ def experimental_create_eval_graph(input_graph=None,
default graph.
weight_bits: Number of bits to use for quantizing weights.
activation_bits: Number of bits to use for quantizing activations.
+ quant_delay: Number of steps after which weights and activations are
+ quantized during eval.
scope: The scope to be transformed. If it's not None, only the ops which
are in this scope will be transformed.
@@ -221,4 +224,5 @@ def experimental_create_eval_graph(input_graph=None,
is_training=False,
weight_bits=weight_bits,
activation_bits=activation_bits,
+ quant_delay=quant_delay,
scope=scope)
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 5e3af0a567..31a2955ddb 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -654,8 +654,80 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph_def_after = str(graph.as_graph_def())
self.assertEqual(graph_def_before, graph_def_after)
- def _BatchNormParams(self, fused=False):
- return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
+ def testBatchNormForcedUpdates(self):
+ parameter_list = [
+ # (activation, activation_op_name, fused_batch_norm)
+ (nn_ops.relu6, 'Relu6', False),
+ (nn_ops.relu, 'Relu', False),
+ (array_ops.identity, 'Identity', False),
+ (nn_ops.relu6, 'Relu6', True),
+ (nn_ops.relu, 'Relu', True),
+ (array_ops.identity, 'Identity', True),
+ ]
+ for params in parameter_list:
+ self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False)
+ self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True)
+
+ def _TestBatchNormForcedUpdates(self, activation, activation_op_name,
+ fused_batch_norm, use_resource):
+ """post_activation bypass quantization should happen with forced updates."""
+ graph = ops.Graph()
+ with graph.as_default():
+ variable_scope.get_variable_scope().set_use_resource(use_resource)
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
+ # Setting updates_collections to None forces the updates to run inline,
+ # which adds an extra Identity operation after each batch norm.
+ bn_params = self._BatchNormParams(
+ fused=fused_batch_norm, force_updates=True)
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation,
+ normalizer_fn=batch_norm,
+ normalizer_params=bn_params,
+ scope='test/test')
+ bypass_tensor = math_ops.add(conv, input2, name='test/add')
+ # The output of the post_activation bypass will be another layer.
+ _ = conv2d(
+ bypass_tensor,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ normalizer_fn=batch_norm,
+ normalizer_params=bn_params,
+ activation_fn=activation,
+ scope='test/unused')
+
+ fold_batch_norms.FoldBatchNorms(graph, is_training=True)
+ quantize.Quantize(graph, is_training=True)
+
+ # Ensure that the bypass node is preceded by and followed by a
+ # FakeQuantWithMinMaxVar operation, since the output of the Add isn't an
+ # activation.
+ self.assertTrue('FakeQuantWithMinMaxVars' in
+ [c.type for c in bypass_tensor.consumers()])
+ self.assertTrue('FakeQuantWithMinMaxVars' in
+ [i.op.type for i in bypass_tensor.op.inputs])
+
+ with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
+ f.write(str(graph.as_graph_def()))
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ if force_updates:
+ params['updates_collections'] = None
+ return params
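+
+ # With updates_collections=None, tf.contrib.layers.batch_norm runs the
+ # moving-average updates inline via control dependencies, which leaves an
+ # extra Identity op after the batch norm output; that is the op the
+ # batch_norm_identity pattern in quantize.py matches.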
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 4eb5c920b3..2a84629080 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -118,7 +118,6 @@ cuda_py_tests(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
"//tensorflow/python:rnn",
"//tensorflow/python:rnn_cell",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 07227bcb77..cb437f2a2f 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -59,6 +59,9 @@ See @{$python/contrib.rnn} guide.
@@HighwayWrapper
@@GLSTMCell
@@SRUCell
+@@IndRNNCell
+@@IndyGRUCell
+@@IndyLSTMCell
<!--RNNCell wrappers-->
@@AttentionCellWrapper
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 86f1e27abd..85f0f8ced9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import numpy as np
@@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
@@ -117,6 +115,27 @@ class RNNCellTest(test.TestCase):
})
self.assertEqual(res[0].shape, (1, 2))
+ def testIndRNNCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.IndRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertEqual([
+ "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ self.assertEqual(res[0].shape, (1, 2))
+
def testGRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -145,6 +164,34 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testIndyGRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.185265, 0.17704]])
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyGRUCell with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.155127, 0.157328]])
+
def testSRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -345,6 +392,72 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_mem0)
self.assertAllClose(res[2], expected_mem1)
+ def testIndyLSTMCell(self):
+ for dtype in [dtypes.float16, dtypes.float32]:
+ np_dtype = dtype.as_numpy_dtype
+ with self.test_session(graph=ops.Graph()) as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2], dtype=dtype)
+ state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ cell = rnn_cell_impl.MultiRNNCell(
+ [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)])
+ self.assertEqual(cell.dtype, None)
+ self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name)
+ self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name)
+ cell.get_config() # Should not throw an error
+ g, (out_state_0, out_state_1) = cell(x, (state_0, state_1))
+ # Layer infers the input type.
+ self.assertEqual(cell.dtype, dtype.name)
+ expected_variable_names = [
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME
+ ]
+ self.assertEqual(expected_variable_names,
+ [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state_0, out_state_1], {
+ x.name: np.array([[1., 1.]]),
+ state_0[0].name: 0.1 * np.ones([1, 2]),
+ state_0[1].name: 0.1 * np.ones([1, 2]),
+ state_1[0].name: 0.1 * np.ones([1, 2]),
+ state_1[1].name: 0.1 * np.ones([1, 2]),
+ })
+ self.assertEqual(len(res), 3)
+ variables = variables_lib.global_variables()
+ self.assertEqual(expected_variable_names, [v.name for v in variables])
+ # Only check the range of outputs as this is just a smoke test.
+ self.assertAllInRange(res[0], -1.0, 1.0)
+ self.assertAllInRange(res[1], -1.0, 1.0)
+ self.assertAllInRange(res[2], -1.0, 1.0)
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyLSTMCell with input_size != num_units.
+ x = array_ops.zeros([1, 3], dtype=dtype)
+ state = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state], {
+ x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+ state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ })
+ self.assertEqual(len(res), 2)
+
def testLSTMCell(self):
with self.test_session() as sess:
num_units = 8
@@ -935,50 +1048,6 @@ class DropoutWrapperTest(test.TestCase):
self.assertAllClose(res0[1].h, res1[1].h)
-class SlimRNNCellTest(test.TestCase):
-
- def testBasicRNNCell(self):
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 2])
- m = array_ops.zeros([1, 2])
- my_cell = functools.partial(basic_rnn_cell, num_units=2)
- # pylint: disable=protected-access
- g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
- # pylint: enable=protected-access
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g], {
- x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])
- })
- self.assertEqual(res[0].shape, (1, 2))
-
- def testBasicRNNCellMatch(self):
- batch_size = 32
- input_size = 100
- num_units = 10
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- inputs = random_ops.random_uniform((batch_size, input_size))
- _, initial_state = basic_rnn_cell(inputs, None, num_units)
- rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
- outputs, state = rnn_cell(inputs, initial_state)
- variable_scope.get_variable_scope().reuse_variables()
- my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
- # pylint: disable=protected-access
- slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
- # pylint: enable=protected-access
- slim_outputs, slim_state = slim_cell(inputs, initial_state)
- self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
- self.assertEqual(slim_state.get_shape(), state.get_shape())
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([slim_outputs, slim_state, outputs, state])
- self.assertAllClose(res[0], res[2])
- self.assertAllClose(res[1], res[3])
-
-
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
if inputs is not None:
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index b12e2cd5ed..1816b469ee 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -23,6 +23,7 @@ import math
from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@@ -30,6 +31,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl # pylint: disable=unused-import
@@ -3050,3 +3052,343 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
+
+
+class IndRNNCell(rnn_cell_impl.LayerRNNCell):
+ """Independently Recurrent Neural Network (IndRNN) cell
+ (cf. https://arxiv.org/abs/1803.04831).
+
+ Args:
+ num_units: int, The number of units in the RNN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units])
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=init_ops.zeros_initializer(dtype=self.dtype))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """IndRNN: output = new_state = act(W * input + u * state + B)."""
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
+ state * self._kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+ output = self._activation(gate_inputs)
+ return output, output
+
+
+class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
+ r"""Independently Gated Recurrent Unit cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
+ yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and
+ 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
+ matrices, i.e. a Hadamard product with a single vector:
+
+ $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
+ [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
+ [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
+ [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU
+ node sees only its own state, as opposed to seeing all states in the same
+ layer.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+
+ Args:
+ num_units: int, The number of units in the GRU cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ kernel_initializer: (optional) The initializer to use for the weight
+ matrices applied to the input.
+ bias_initializer: (optional) The initializer to use for the bias.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None,
+ dtype=None):
+ super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._gate_kernel_w = self.add_variable(
+ "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 2 * self._num_units],
+ initializer=self._kernel_initializer)
+ self._gate_kernel_u = self.add_variable(
+ "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 2 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._gate_bias = self.add_variable(
+ "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[2 * self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.constant_initializer(1.0, dtype=self.dtype)))
+ self._candidate_kernel_w = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units],
+ initializer=self._kernel_initializer)
+ self._candidate_kernel_u = self.add_variable(
+ "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._candidate_bias = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.zeros_initializer(dtype=self.dtype)))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Gated recurrent unit (GRU) with nunits cells."""
+
+ gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
+ gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+ value = math_ops.sigmoid(gate_inputs)
+ r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+ r_state = r * state
+
+ candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
+ r_state * self._candidate_kernel_u)
+ candidate = nn_ops.bias_add(candidate, self._candidate_bias)
+
+ c = self._activation(candidate)
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
+
+class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
+ r"""Basic IndyLSTM recurrent network cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
+ BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\)
+ matrices in
+ https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
+ replaced by diagonal matrices, i.e. a Hadamard product with a single vector:
+
+ $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
+ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
+ $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
+ $$c_t = f_t \circ c_{t-1} +
+ i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM
+ node sees only its own state \(h\) and \(c\), as opposed to seeing all
+ states in the same layer.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting at the beginning of training.
+
+ It does not allow cell clipping or a projection layer, and it does not
+ use peephole connections: it is the basic baseline.
+
+ For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+ """
+
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None,
+ dtype=None):
+ """Initialize the IndyLSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ Must set to `0.0` manually when restoring from CudnnLSTM-trained
+ checkpoints.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ kernel_initializer: (optional) The initializer to use for the weight
+ matrix applied to the inputs.
+ bias_initializer: (optional) The initializer to use for the bias.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+ super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 4 * self._num_units],
+ initializer=self._kernel_initializer)
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 4 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[4 * self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.zeros_initializer(dtype=self.dtype)))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Independent Long short-term memory cell (IndyLSTM).
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size, num_units]`.
+
+ Returns:
+ A pair containing the new hidden state and the new state (an
+ `LSTMStateTuple`).
+ """
+ sigmoid = math_ops.sigmoid
+ one = constant_op.constant(1, dtype=dtypes.int32)
+ c, h = state
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w)
+ gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(
+ value=gate_inputs, num_or_size_splits=4, axis=one)
+
+ forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
+ # Note that using `add` and `multiply` instead of `+` and `*` gives a
+ # performance improvement. So using those at the cost of readability.
+ add = math_ops.add
+ multiply = math_ops.multiply
+ new_c = add(
+ multiply(c, sigmoid(add(f, forget_bias_tensor))),
+ multiply(sigmoid(i), self._activation(j)))
+ new_h = multiply(self._activation(new_c), sigmoid(o))
+
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+ return new_h, new_state
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index 2311c15a68..cb0b89ae55 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -1,5 +1,3 @@
-# TODO(b/76425722): Port everything in here to OS (currently excluded).
-
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
@@ -17,7 +15,6 @@ tf_proto_library(
srcs = ["test_example.proto"],
has_services = 1,
cc_api_version = 2,
- protodeps = ["//tensorflow/core:protos_all"],
)
py_library(
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 27273d16b1..1c23c28860 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -51,23 +51,23 @@ class RpcOpTestBase(object):
def testScalarHostPortRpc(self):
with self.test_session() as sess:
request_tensors = (
- test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, ())
response_values = sess.run(response_tensors)
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
- self.assertAllEqual([2, 3, 4], response_message.shape)
+ self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
with self.test_session() as sess:
request_tensors = (
- test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+ test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(status_code.shape, ())
@@ -77,7 +77,7 @@ class RpcOpTestBase(object):
sess.run((response_tensors, status_code, status_message)))
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
- self.assertAllEqual([2, 3, 4], response_message.shape)
+ self.assertAllEqual([2, 3, 4], response_message.values)
# For the base Rpc op, don't expect to get error status back.
self.assertEqual(errors.OK, status_code_values)
self.assertEqual(b'', status_message_values)
@@ -86,7 +86,7 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = []
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertAllEqual(response_tensors.shape, [0])
@@ -95,7 +95,7 @@ class RpcOpTestBase(object):
def testInvalidMethod(self):
for method in [
- '/InvalidService.IncrementTestShapes',
+ '/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
with self.test_session() as sess:
@@ -115,12 +115,12 @@ class RpcOpTestBase(object):
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=address,
request=''))
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=address,
request=''))
self.assertEqual(errors.UNAVAILABLE, status_code_value)
@@ -182,10 +182,10 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, (20,))
@@ -194,17 +194,17 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
many_response_tensors = [
self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors) for _ in range(10)
]
@@ -216,25 +216,25 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
with self.test_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
- field_names=['shape'],
+ field_names=['values'],
sizes=[[3]] * 20,
values=[
[[i, i + 1, i + 2] for i in range(20)],
])
response_tensor_strings = self.rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
_, (response_shape,) = decode_proto_op.decode_proto(
bytes=response_tensor_strings,
message_type='tensorflow.contrib.rpc.TestCase',
- field_names=['shape'],
+ field_names=['values'],
output_types=[dtypes.int32])
response_shape_values = sess.run(response_shape)
self.assertAllEqual([[i + 1, i + 2, i + 3]
@@ -285,9 +285,9 @@ class RpcOpTestBase(object):
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
- request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=addresses,
request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -303,9 +303,9 @@ class RpcOpTestBase(object):
flatten = lambda x: list(itertools.chain.from_iterable(x))
with self.test_session() as sess:
methods = flatten(
- [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
+ [[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
- request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+ request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
method=methods, address=self._address, request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -325,10 +325,10 @@ class RpcOpTestBase(object):
] for _ in range(10)])
requests = [
test_example_pb2.TestCase(
- shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+ values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors, status_code, _ = self.try_rpc(
- method=self.get_method_name('IncrementTestShapes'),
+ method=self.get_method_name('Increment'),
address=addresses,
request=requests)
response_tensors_values, status_code_values = sess.run((response_tensors,
@@ -343,4 +343,4 @@ class RpcOpTestBase(object):
response_message = test_example_pb2.TestCase()
self.assertTrue(
response_message.ParseFromString(response_tensors_values[i]))
- self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+ self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
index 7cbd636cb1..265254aa51 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
@@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
"""Test servicer for RpcOp tests."""
- def IncrementTestShapes(self, request, context):
- """Increment the entries in the shape attribute of request.
+ def Increment(self, request, context):
+ """Increment the entries in the `values` attribute of request.
Args:
request: input TestCase.
@@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
Returns:
output TestCase.
"""
- for i in range(len(request.shape)):
- request.shape[i] += 1
+ for i in range(len(request.values)):
+ request.values[i] += 1
return request
def AlwaysFailWithInvalidArgument(self, request, context):
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
index 96f4550f62..8141466349 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
+++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
@@ -1,29 +1,17 @@
// Test description and protos to work with it.
-//
-// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2";
-import "tensorflow/core/framework/types.proto";
-
package tensorflow.contrib.rpc;
-// A TestCase holds a proto and a bunch of assertions
-// about how it should decode.
+// A TestCase holds a sequence of values.
message TestCase {
- // A batch of primitives to be serialized and decoded.
- repeated RepeatedPrimitiveValue primitive = 1;
- // The shape of the batch.
- repeated int32 shape = 2;
- // Expected sizes for each field.
- repeated int32 sizes = 3;
- // Expected values for each field.
- repeated FieldSpec field = 4;
+ repeated int32 values = 1;
};
service TestCaseService {
- // Copy input, and increment each entry in 'shape' by 1.
- rpc IncrementTestShapes(TestCase) returns (TestCase) {
+ // Copy input, and increment each entry in 'values' by 1.
+ rpc Increment(TestCase) returns (TestCase) {
}
// Sleep forever.
@@ -42,130 +30,3 @@ service TestCaseService {
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
}
};
-
-// FieldSpec describes the expected output for a single field.
-message FieldSpec {
- optional string name = 1;
- optional tensorflow.DataType dtype = 2;
- optional RepeatedPrimitiveValue expected = 3;
-};
-
-message TestValue {
- optional PrimitiveValue primitive_value = 1;
- optional EnumValue enum_value = 2;
- optional MessageValue message_value = 3;
- optional RepeatedMessageValue repeated_message_value = 4;
- optional RepeatedPrimitiveValue repeated_primitive_value = 6;
-}
-
-message PrimitiveValue {
- optional double double_value = 1;
- optional float float_value = 2;
- optional int64 int64_value = 3;
- optional uint64 uint64_value = 4;
- optional int32 int32_value = 5;
- optional fixed64 fixed64_value = 6;
- optional fixed32 fixed32_value = 7;
- optional bool bool_value = 8;
- optional string string_value = 9;
- optional bytes bytes_value = 12;
- optional uint32 uint32_value = 13;
- optional sfixed32 sfixed32_value = 15;
- optional sfixed64 sfixed64_value = 16;
- optional sint32 sint32_value = 17;
- optional sint64 sint64_value = 18;
-}
-
-// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
-message RepeatedPrimitiveValue {
- repeated double double_value = 1;
- repeated float float_value = 2;
- repeated int64 int64_value = 3;
- repeated uint64 uint64_value = 4;
- repeated int32 int32_value = 5;
- repeated fixed64 fixed64_value = 6;
- repeated fixed32 fixed32_value = 7;
- repeated bool bool_value = 8;
- repeated string string_value = 9;
- repeated bytes bytes_value = 12;
- repeated uint32 uint32_value = 13;
- repeated sfixed32 sfixed32_value = 15;
- repeated sfixed64 sfixed64_value = 16;
- repeated sint32 sint32_value = 17;
- repeated sint64 sint64_value = 18;
- repeated PrimitiveValue message_value = 19;
-}
-
-// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
-// in the text format, but the binary serializion is different.
-// We test the packed representations by loading the same test cases
-// using this definition instead of RepeatedPrimitiveValue.
-// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
-// in every way except the packed=true declaration.
-message PackedPrimitiveValue {
- repeated double double_value = 1 [packed = true];
- repeated float float_value = 2 [packed = true];
- repeated int64 int64_value = 3 [packed = true];
- repeated uint64 uint64_value = 4 [packed = true];
- repeated int32 int32_value = 5 [packed = true];
- repeated fixed64 fixed64_value = 6 [packed = true];
- repeated fixed32 fixed32_value = 7 [packed = true];
- repeated bool bool_value = 8 [packed = true];
- repeated string string_value = 9;
- repeated bytes bytes_value = 12;
- repeated uint32 uint32_value = 13 [packed = true];
- repeated sfixed32 sfixed32_value = 15 [packed = true];
- repeated sfixed64 sfixed64_value = 16 [packed = true];
- repeated sint32 sint32_value = 17 [packed = true];
- repeated sint64 sint64_value = 18 [packed = true];
- repeated PrimitiveValue message_value = 19;
-}
-
-message EnumValue {
- enum Color {
- RED = 0;
- ORANGE = 1;
- YELLOW = 2;
- GREEN = 3;
- BLUE = 4;
- INDIGO = 5;
- VIOLET = 6;
- };
- optional Color enum_value = 14;
- repeated Color repeated_enum_value = 15;
-}
-
-
-message InnerMessageValue {
- optional float float_value = 2;
- repeated bytes bytes_values = 8;
-}
-
-message MiddleMessageValue {
- repeated int32 int32_values = 5;
- optional InnerMessageValue message_value = 11;
- optional uint32 uint32_value = 13;
-}
-
-message MessageValue {
- optional double double_value = 1;
- optional MiddleMessageValue message_value = 11;
-}
-
-message RepeatedMessageValue {
- message NestedMessageValue {
- optional float float_value = 2;
- repeated bytes bytes_values = 8;
- }
-
- repeated NestedMessageValue message_values = 11;
-}
-
-// Message containing fields with field numbers higher than any field above. An
-// instance of this message is prepended to each binary message in the test to
-// exercise the code path that handles fields encoded out of order of field
-// number.
-message ExtraFields {
- optional string string_value = 1776;
- optional bool bool_value = 1777;
-}
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 178328619f..4073b390fc 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase):
def test_gather_tree_from_array_2d(self):
self._test_gather_tree_from_array(depth_ndims=2)
+ def test_gather_tree_from_array_complex_trajectory(self):
+ # Max. time = 7, batch = 1, beam = 5.
+ array = np.expand_dims(np.array(
+ [[[25, 12, 114, 89, 97]],
+ [[9, 91, 64, 11, 162]],
+ [[34, 34, 34, 34, 34]],
+ [[2, 4, 2, 2, 4]],
+ [[2, 3, 6, 2, 2]],
+ [[2, 2, 2, 3, 2]],
+ [[2, 2, 2, 2, 2]]]), -1)
+ parent_ids = np.array(
+ [[[0, 0, 0, 0, 0]],
+ [[0, 0, 0, 0, 0]],
+ [[0, 1, 2, 3, 4]],
+ [[0, 0, 1, 2, 1]],
+ [[0, 1, 1, 2, 3]],
+ [[0, 1, 3, 1, 2]],
+ [[0, 1, 2, 3, 4]]])
+ expected_array = np.expand_dims(np.array(
+ [[[25, 25, 25, 25, 25]],
+ [[9, 9, 91, 9, 9]],
+ [[34, 34, 34, 34, 34]],
+ [[2, 4, 2, 4, 4]],
+ [[2, 3, 6, 3, 6]],
+ [[2, 2, 2, 3, 2]],
+ [[2, 2, 2, 2, 2]]]), -1)
+ sequence_length = [[4, 6, 4, 7, 6]]
+
+ array = ops.convert_to_tensor(
+ array, dtype=dtypes.float32)
+ parent_ids = ops.convert_to_tensor(
+ parent_ids, dtype=dtypes.int32)
+ expected_array = ops.convert_to_tensor(
+ expected_array, dtype=dtypes.float32)
+
+ sorted_array = beam_search_decoder.gather_tree_from_array(
+ array, parent_ids, sequence_length)
+
+ with self.test_session() as sess:
+ sorted_array, expected_array = sess.run([sorted_array, expected_array])
+ self.assertAllEqual(expected_array, sorted_array)
+
class TestArrayShapeChecks(test.TestCase):
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index c7fbeea310..f17dbb0fe3 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length):
array_ops.expand_dims(math_ops.range(beam_width), 0), 0)
beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1])
- mask = array_ops.sequence_mask(
- sequence_length, maxlen=max_time, dtype=dtypes.int32)
- mask = array_ops.transpose(mask, perm=[2, 0, 1])
-
- # Use beam_width + 1 to mark the end of beam.
- masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1)
-
max_sequence_lengths = math_ops.to_int32(
math_ops.reduce_max(sequence_length, axis=1))
sorted_beam_ids = beam_search_ops.gather_tree(
- step_ids=masked_beam_ids,
+ step_ids=beam_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=beam_width + 1)
# For out of range steps, simply copy the same beam.
+ in_bound_steps = array_ops.transpose(
+ array_ops.sequence_mask(sequence_length, maxlen=max_time),
+ perm=[2, 0, 1])
sorted_beam_ids = array_ops.where(
- math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids)
+ in_bound_steps, x=sorted_beam_ids, y=beam_ids)
# Generate indices for gather_nd.
time_ind = array_ops.tile(array_ops.reshape(
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index 3f6b4cdc9a..6507546ee9 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -106,6 +106,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:png_internal",
"//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 189944f29b..6ebc30ca82 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -86,27 +86,48 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(jie): Segmentation shouldn't associated with op name.
// Split it into a registration for each kernel.
static const std::set<string> candidate_ops = {
- "Identity",
- "Snapshot",
- "Const",
- "Conv2D",
- "MaxPool",
- "BiasAdd",
- "Relu",
- "Add",
- "Mul",
- "Sub",
- "Rsqrt",
- "Pad",
- "Mean",
- "AvgPool",
- "ConcatV2",
- "DepthwiseConv2dNative",
- "FusedBatchNorm",
- "FusedBatchNormV2",
- // TODO(ben,jie): ...
+ "Identity",
+ "Snapshot",
+ "Const",
+ "Conv2D",
+ "MaxPool",
+ "BiasAdd",
+ "Relu",
+ "Add",
+ "Mul",
+ "Sub",
+ "Rsqrt",
+ "Pad",
+ "Mean",
+ "AvgPool",
+ "ConcatV2",
+ "DepthwiseConv2dNative",
+ "FusedBatchNorm",
+ "FusedBatchNormV2",
+ "Div",
+ "RealDiv",
+ "Rsqrt",
+ "Reciprocal",
+ "Exp",
+ "Log",
+ "Sqrt",
+ "Abs",
+ "Neg",
+#if NV_TENSORRT_MAJOR > 3
+ "MatMul",
+ "BatchMatMul",
+ "Softmax",
+ "Minimum",
+ "Maximum",
+ "TopKV2",
+ "Sum",
+ "Prod",
+ "Max",
+ "Min",
+#endif
+ // TODO(ben,jie): ...
};
- // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
+ // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc)
return (candidate_ops.count(node->type_string()) ||
PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
}
@@ -168,7 +189,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
"Can't get TRTCalibrator from resource manager!");
}
cres->Unref();
- calib_rm->Cleanup(container_name);
+ TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name));
}
}
return tensorflow::Status::OK();
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 146b9c7344..0ee708bc1c 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -49,9 +49,29 @@ limitations under the License.
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
-// Check if the types are equal. Cast to int first so that failure log message
-// would work!
-#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+// Check if the types are equal. Cast to int first so that the failure log
+// message works.
+#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+
+#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \
+ do { \
+ return tensorflow::errors::Internal( \
+ "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \
+ } while (0)
+
+#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
+ do { \
+ if (status == false) { \
+ TFTRT_INTERNAL_ERROR_AT_NODE(node); \
+ } \
+ } while (0)
+
+#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
+ do { \
+ if (ptr == nullptr) { \
+ TFTRT_INTERNAL_ERROR_AT_NODE(node); \
+ } \
+ } while (0)
namespace tensorflow {
namespace tensorrt {
@@ -75,13 +95,110 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
case tensorflow::DataType::DT_HALF:
*trt_dtype = nvinfer1::DataType::kHALF;
break;
+#if NV_TENSORRT_MAJOR > 3
+ case tensorflow::DataType::DT_INT32:
+ *trt_dtype = nvinfer1::DataType::kINT32;
+ break;
+#endif
default:
return tensorflow::errors::InvalidArgument(
- "Unsupported data type " + tensorflow::DataTypeString(tf_dtype));
+ "Unsupported data type ", tensorflow::DataTypeString(tf_dtype));
}
return tensorflow::Status::OK();
}
+// Returns whether or not the broadcast is feasible.
+bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
+ const bool operand_l_is_tensor,
+ const nvinfer1::Dims& operand_r,
+ const bool operand_r_is_tensor,
+ nvinfer1::Dims* operand_l_new_shape,
+ nvinfer1::Dims* operand_r_new_shape) {
+ // ***************************************************************************
+ // The TensorRT Elementwise op supports broadcast but requires both tensors
+ // to be of identical rank.
+ //
+ // We consider case of:
+ // 1. operand_l to be a Tensor & operand_r to be a Const;
+ // 2. operand_l to be a Tensor & operand_r to be a Tensor;
+ // note: const op const (constant folding) should fall back to TensorFlow
+ //
+ // broadcast scheme:
+ // T: 1 3 5 (tensor would not have batch dimension)
+ // W: 1 1 3 1 (weight would have all explicit dimensions)
+ // i. fill in explicit dimensions
+ // -> T: -1 1 3 5 (we put a -1 for batch dimension)
+ // -> W: 1 1 3 1
+ // ii. compare broadcast feasibility
+ //
+ // We cannot support the following, since TensorRT does not allow
+ // manipulation of the batch dimension and we cannot generate an output
+ // with the proper shape:
+ // T: 3 5 1
+ // W: 1 1 1 1 3 5 1
+ // -> T: 1 1 1 -1 3 5 1
+ // -> W: 1 1 1 1 3 5 1
+ // ***************************************************************************
+ const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
+ const size_t element_size = sizeof(operand_l.d[0]);
+
+ // fill in dimensions
+ int l_s[max_nb_dims];
+ std::fill(l_s, l_s + max_nb_dims, 1);
+ int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims;
+ int r_s[max_nb_dims];
+ std::fill(r_s, r_s + max_nb_dims, 1);
+ int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims;
+
+ int max_d = std::max(l_d, r_d);
+ std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d,
+ operand_l.nbDims * element_size);
+ std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d,
+ operand_r.nbDims * element_size);
+
+ // set -1 for batch dimension, since batch size is not supposed to be
+ // broadcasted
+ if (operand_l_is_tensor) {
+ if (max_d != l_d) { // if broadcast beyond batch dimension, fail
+ return false;
+ }
+ l_s[0] = -1;
+ }
+ if (operand_r_is_tensor) {
+ if (max_d != r_d) { // if broadcast beyond batch dimension, fail
+ return false;
+ }
+ r_s[0] = -1;
+ }
+
+ // compare broadcast feasibility
+ for (int i = max_d - 1; i >= 0; i--) {
+ if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) {
+ return false;
+ }
+ }
+
+ // output new TensorRT Dimension (stripping the batch dimension)
+ operand_l_new_shape->nbDims = max_d - 1;
+ std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size);
+ operand_r_new_shape->nbDims = max_d - 1;
+ std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size);
+
+ return true;
+}
+
+inline bool DimsEqual(const nvinfer1::Dims& dim_l,
+ const nvinfer1::Dims& dim_r) {
+ if (dim_l.nbDims != dim_r.nbDims) {
+ return false;
+ }
+ for (int i = 0; i < dim_l.nbDims; i++) {
+ if (dim_l.d[i] != dim_r.d[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
nvinfer1::Dims dims;
dims.nbDims = tensor.dims();
@@ -91,7 +208,7 @@ inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
return dims;
}
-inline int64_t GetShapeSize(nvinfer1::Dims shape) {
+inline int64_t GetShapeSize(const nvinfer1::Dims& shape) {
// Returns total number of elements in shape
int64_t count = 1;
for (int d = 0; d < shape.nbDims; ++d) {
@@ -104,7 +221,7 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
const std::vector<int64_t>& input_dims) {
std::vector<std::pair<int, int>> padding(input_dims.size());
- CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
+ CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
for (size_t i = 0; i < input_dims.size(); ++i) {
// Formula to calculate the padding
@@ -134,6 +251,7 @@ string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
return op_name_a.substr(0, last_scope_separator);
}
+// Class to convert TF weight to TRT weight.
class TRT_ShapedWeights {
public:
TRT_ShapedWeights(tensorflow::DataType type, const void* values,
@@ -145,12 +263,14 @@ class TRT_ShapedWeights {
explicit TRT_ShapedWeights(tensorflow::DataType type)
: shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
+ // TODO(aaroey): use rvalue reference.
TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
: shape_(rhs.shape_),
type_(rhs.type_),
values_(rhs.values_),
empty_weight_flag_(rhs.empty_weight_flag_) {}
+ // TODO(aaroey): use GetShapeSize() instead.
int64_t count() const {
int64_t c = 1;
for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
@@ -168,6 +288,7 @@ class TRT_ShapedWeights {
const void* GetValues() const { return values_; }
+ // TODO(aaroey): get rid of this method.
void SetValues(const void* values) { values_ = values; }
size_t size_bytes() const {
@@ -178,10 +299,12 @@ class TRT_ShapedWeights {
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+ // TODO(aaroey): make these private.
nvinfer1::Dims shape_;
tensorflow::DataType type_;
private:
+ // TODO(aaroey): this should not be const as it's always from TRTWeightStore.
const void* values_;
bool empty_weight_flag_;
};
@@ -192,6 +315,7 @@ class TRT_TensorOrWeights {
: tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
: tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+ // TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
: tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
~TRT_TensorOrWeights() {}
@@ -200,19 +324,19 @@ class TRT_TensorOrWeights {
bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
nvinfer1::ITensor* tensor() {
- CHECK_EQ(is_tensor(), true);
+ CHECK(is_tensor());
return tensor_;
}
const nvinfer1::ITensor* tensor() const {
- CHECK_EQ(is_tensor(), true);
+ CHECK(is_tensor());
return tensor_;
}
TRT_ShapedWeights& weights() {
- CHECK_EQ(is_weights(), true);
+ CHECK(is_weights());
return weights_;
}
const TRT_ShapedWeights& weights() const {
- CHECK_EQ(is_weights(), true);
+ CHECK(is_weights());
return weights_;
}
nvinfer1::Dims shape() const {
@@ -236,21 +360,25 @@ class TFAttrs {
attrs_.insert({attr.first, &attr.second});
}
}
- bool count(string key) const { return attrs_.count(key); }
- tensorflow::AttrValue const* at(string key) const {
+
+ bool count(const string& key) const { return attrs_.count(key); }
+
+ tensorflow::AttrValue const* at(const string& key) const {
if (!attrs_.count(key)) {
LOG(FATAL) << "Attribute not found: " << key;
}
return attrs_.at(key);
}
+
template <typename T>
T get(const string& key) const;
+
template <typename T>
T get(const string& key, const T& default_value) const {
return attrs_.count(key) ? this->get<T>(key) : default_value;
}
- std::vector<string> GetAllAttrKey() {
+ std::vector<string> GetAllAttrKeys() const {
std::vector<string> attr_list;
for (const auto& attr_item : attrs_) {
attr_list.emplace_back(attr_item.first);
@@ -285,15 +413,6 @@ std::vector<string> TFAttrs::get<std::vector<string>>(const string& key) const {
auto attr = this->at(key)->list().s();
return std::vector<string>(attr.begin(), attr.end());
}
-template <>
-nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(const string& key) const {
- auto values = this->get<std::vector<int>>(key);
- nvinfer1::Dims dims;
- dims.nbDims = values.size();
- std::copy(values.begin(), values.end(), dims.d);
- // Note: No dimension type information is included
- return dims;
-}
template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
@@ -319,10 +438,11 @@ bool TFAttrs::get<bool>(const string& key) const {
}
// TODO(jie): reorder4 & reorder2 should be merged?
+// TODO(aaroey): fix the order of parameters.
template <typename T>
-void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
- nvinfer1::DimsNCHW istrides, T* odata,
- nvinfer1::DimsNCHW ostrides) {
+void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata,
+ const nvinfer1::DimsNCHW& istrides, T* odata,
+ const nvinfer1::DimsNCHW& ostrides) {
for (int n = 0; n < shape.n(); ++n) {
for (int c = 0; c < shape.c(); ++c) {
for (int h = 0; h < shape.h(); ++h) {
@@ -337,12 +457,13 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
}
template <typename T>
-void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
- T* odata, nvinfer1::DimsHW ostrides) {
+void Reorder2(const nvinfer1::DimsHW& shape, const T* idata,
+ const nvinfer1::DimsHW& istrides, T* odata,
+ const nvinfer1::DimsHW& ostrides) {
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
- idata[h * ostrides.h() + w * ostrides.w()];
+ idata[h * istrides.h() + w * istrides.w()];
}
}
}
@@ -350,16 +471,17 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
// TODO(jie): fallback to tensorflow!!
void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights) {
- int c = iweights.shape_.d[0];
- int k = iweights.shape_.d[1];
+ const int c = iweights.shape_.d[0];
+ const int k = iweights.shape_.d[1];
oweights->shape_.d[0] = k;
oweights->shape_.d[1] = c;
- nvinfer1::DimsHW istrides = {1, k};
- nvinfer1::DimsHW ostrides = {c, 1};
+ const nvinfer1::DimsHW istrides = {1, k};
+ const nvinfer1::DimsHW ostrides = {c, 1};
switch (iweights.type_) {
case tensorflow::DataType::DT_FLOAT: {
Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
istrides,
+ // TODO(aaroey): get rid of all the const_cast like this.
static_cast<float*>(const_cast<void*>(oweights->GetValues())),
ostrides);
break;
@@ -382,21 +504,24 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights, int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
- int r = iweights.shape_.d[0];
- int s = iweights.shape_.d[1];
- // TRT requires GKcRS, while TF depthwise has RSCK
- // where c=1, C=G
+ // K indexes over output channels, C over input channels, and R and S over the
+ // height and width of the convolution
+ const int r = iweights.shape_.d[0];
+ const int s = iweights.shape_.d[1];
+ // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
VLOG(2) << "num_groups: " << num_groups;
- int c = iweights.shape_.d[2] / num_groups;
+ const int c = iweights.shape_.d[2] / num_groups;
VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
- int k = iweights.shape_.d[3] * num_groups;
+ const int k = iweights.shape_.d[3] * num_groups;
VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
+ VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r;
+ VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s;
oweights->shape_.d[0] = k / num_groups;
oweights->shape_.d[1] = c * num_groups;
oweights->shape_.d[2] = r;
oweights->shape_.d[3] = s;
- nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
- nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+ const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+ const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
switch (iweights.type_) {
case tensorflow::DataType::DT_FLOAT: {
Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
@@ -428,11 +553,14 @@ using OpConverter =
std::vector<TRT_TensorOrWeights>*)>;
class Converter {
+ // TODO(aaroey): fix the order of members.
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
std::unordered_map<string, OpConverter> op_registry_;
OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
+ // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
+ // operate on the stored weights instead of manipulating the store directly.
TRTWeightStore* weight_store_;
bool fp16_;
void register_op_converters();
@@ -440,7 +568,7 @@ class Converter {
std::vector<TRT_TensorOrWeights>* inputs) {
for (auto const& input_name : node_def.input()) {
/*************************************************************************
- * TODO(jie) handle case 1) here
+ * TODO(jie): handle case 1) here.
* Normalizes the inputs and extracts associated metadata:
* 1) Inputs can contain a colon followed by a suffix of characters.
* That suffix may be a single number (e.g. inputName:1) or several
@@ -454,6 +582,7 @@ class Converter {
if (input_name[0] == '^') continue;
string name = input_name;
auto first = name.find_first_of(':');
+ // TODO(aaroey): why removing the colon but not the zero? A bug?
if (first != string::npos && first + 2 == name.size() &&
name[first + 1] == '0')
name.erase(first);
@@ -462,12 +591,13 @@ class Converter {
if (trt_tensors_.count(name)) {
inputs->push_back(trt_tensors_.at(name));
} else {
- string str("Node ");
- StrAppend(&str, node_def.name(), " should have an input named '", name,
+ // TODO(aaroey): this should not happen, make it a CHECK.
+ // TODO(aaroey): use StrCat for pattern like this.
+ string msg("Node ");
+ StrAppend(&msg, node_def.name(), " should have an input named '", name,
"' but it is not available");
- LOG(WARNING) << "input: " << name << " not available for node at "
- << node_def.name();
- return tensorflow::errors::InvalidArgument(str);
+ LOG(ERROR) << msg;
+ return tensorflow::errors::InvalidArgument(msg);
}
}
return tensorflow::Status::OK();
@@ -488,6 +618,7 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
+ // TODO(aaroey): fix all the namings.
bool isFP16() { return fp16_; }
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
@@ -496,7 +627,7 @@ class Converter {
tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
std::vector<TRT_TensorOrWeights> inputs;
TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs));
- string op = node_def.op();
+ const string& op = node_def.op();
std::vector<TRT_TensorOrWeights> outputs;
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
@@ -509,7 +640,7 @@ class Converter {
TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
}
for (size_t i = 0; i < outputs.size(); ++i) {
- TRT_TensorOrWeights output = outputs.at(i);
+ TRT_TensorOrWeights& output = outputs[i];
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
@@ -527,26 +658,29 @@ class Converter {
nvinfer1::INetworkDefinition* network() { return trt_network_; }
- TRT_TensorOrWeights get_tensor(string name) {
+ TRT_TensorOrWeights get_tensor(const string& name) {
if (!trt_tensors_.count(name)) {
return TRT_TensorOrWeights(nullptr);
}
return trt_tensors_.at(name);
}
- bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
+ bool insert_input_tensor(const string& name, nvinfer1::ITensor* tensor) {
return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
}
nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
- std::vector<int> order) {
- auto dims = input_tensor->getDimensions();
+ const std::vector<int>& order) {
+ const auto dims = input_tensor->getDimensions();
// TODO(jie): change the return to status and properly exit
if (order.size() - 1 != size_t(dims.nbDims))
LOG(ERROR) << "Dimension does not match, fail gracefully";
nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+ if (layer == nullptr) {
+ return nullptr;
+ }
nvinfer1::Permutation permutation;
for (int32_t i = 0; i < dims.nbDims; ++i) {
permutation.order[i] = order[i + 1] - 1;
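The off-by-one here encodes TensorRT's implicit batch dimension. A self-contained sketch of the mapping (TfOrderToTrtOrder is an illustrative name): a TF-style order such as {0, 3, 1, 2} (NHWC to NCHW) becomes the TRT permutation {2, 0, 1} over the batchless dims.

#include <vector>

std::vector<int> TfOrderToTrtOrder(const std::vector<int>& tf_order) {
  // tf_order[0] is expected to be 0 (batch stays put); the remaining entries
  // shift down by one because TRT dimensions exclude the batch dimension.
  std::vector<int> trt_order(tf_order.size() - 1);
  for (size_t i = 1; i < tf_order.size(); ++i) {
    trt_order[i - 1] = tf_order[i] - 1;
  }
  return trt_order;  // {0, 3, 1, 2} -> {2, 0, 1}
}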
@@ -577,13 +711,14 @@ TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
}
return weights;
}
+
// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
// there.
-//*****************************************************************************/
+// *****************************************************************************
struct LambdaFactory {
- enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP };
OP_CATEGORY op;
template <typename T>
@@ -595,6 +730,8 @@ struct LambdaFactory {
}
case OP_CATEGORY::NEG:
return [](T t) -> T { return -t; };
+ case OP_CATEGORY::RECIP:
+ return [](T t) -> T { return 1.0 / t; };
default:
VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
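A quick standalone check of the RECIP lambda added above; note that 1.0 / t is computed in double and then narrowed back to T.

#include <cassert>

int main() {
  auto recip = [](float t) -> float { return 1.0 / t; };
  assert(recip(4.0f) == 0.25f);  // elementwise reciprocal
  return 0;
}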
@@ -628,7 +765,6 @@ struct LambdaFactory {
VLOG(2) << "LAMBDA VAL : " << val;
return l + val;
};
- // Return [val](T l)-> T {return l+val;};
case OP_CATEGORY::SUB:
return [val](T l) -> T {
VLOG(2) << "LAMBDA VAL : " << val;
@@ -688,11 +824,13 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
}
case OP_CATEGORY::NEG:
return [](Eigen::half t) -> Eigen::half { return -t; };
+ // TODO(aaroey): can we support RECIP?
default:
VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
}
}
+
tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights,
LambdaFactory unary_op) {
@@ -738,6 +876,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
if (iweights_l.count() != iweights_r.count()) {
    // We only support broadcast of RankZero
if (iweights_l.count() == 1) {
+ // TODO(aaroey): Remove loggings like this.
VLOG(2) << "I bet it is not working!" << (*inp_l);
std::transform(inp_r, inp_r + iweights_r.count(), oup,
binary_op.broadcast_l<float>(*inp_l));
@@ -790,117 +929,21 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
return tensorflow::Status::OK();
}
-tensorflow::Status ConstantFoldUnary(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- TRT_ShapedWeights weights_input = inputs.at(0).weights();
-
- // Allocate output weights
- TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
-
- // FIXME assume type matches input weights
- // Get trt type & shape
- // Maybe this part has to be moved into the block of rsqrt later
- // Check type consistency
- CHECK_EQ(weights_input.type_,
- TFAttrs(node_def).get<tensorflow::DataType>("T"));
-
- LambdaFactory unary_op;
- if (node_def.op() == "Rsqrt") {
- // Compute rsqrt
- unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
- auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
- // Pass the output
- if (ret == tensorflow::Status::OK()) {
- outputs->push_back(TRT_TensorOrWeights(weights_output));
- }
- return ret;
- } else {
- return tensorflow::errors::Unimplemented("Binary op not supported: " +
- node_def.op());
- }
-}
-
-// TODO(jie,ben) broadcast is needed yet not implemented
-// Let's get the simple stuff working first. Maybe we should fall back to TF
-// approach for constant folding
-tensorflow::Status ConstantFoldBinary(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
- TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
-
- // Check type consistency
- CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
-
- if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
- return tensorflow::errors::Unimplemented(
- "Binary op implicit broadcast not supported: " + node_def.op());
-
- // TODO(jie): constant fold should really fall back to TF.
- int num_dims = weights_input_l.shape_.nbDims;
- nvinfer1::Dims output_shape;
- output_shape.nbDims = num_dims;
- VLOG(2) << "nb_dims: " << num_dims
- << ", the other: " << weights_input_r.shape_.nbDims;
- for (int i = 0; i < num_dims; i++) {
- if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
- output_shape.d[i] = weights_input_l.shape_.d[i];
- } else if (weights_input_l.shape_.d[i] == 1 ||
- weights_input_r.shape_.d[i] == 1) {
- output_shape.d[i] =
- std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
- } else {
- return tensorflow::errors::Unimplemented(
- "Binary op with incompatible shape at, " + node_def.op());
- }
- VLOG(2) << "left: " << weights_input_l.shape_.d[i]
- << "right: " << weights_input_r.shape_.d[i]
- << "output: " << output_shape.d[i];
- }
-
- // FIXME assume type matches input weights
- // Get trt type & shape
- TFAttrs attrs(node_def);
- // Maybe this part has to be moved into the block of rsqrt later
- tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
-
- // Allocate output weights
- TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
-
- LambdaFactory binary_op;
- if (node_def.op() == "Sub") {
- binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
- } else if (node_def.op() == "Mul") {
- binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
- } else if (node_def.op() == "Add") {
- binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
- } else {
- return tensorflow::errors::Unimplemented("Binary op not supported: " +
- node_def.op());
- }
- auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
- binary_op);
-
- // Pass the output
- if (ret == tensorflow::Status::OK()) {
- outputs->push_back(TRT_TensorOrWeights(weights_output));
- }
-
- return ret;
-}
-
// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel wise for the time being
tensorflow::Status BinaryTensorOpWeight(
Converter& ctx, const tensorflow::NodeDef& node_def,
const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
- std::vector<TRT_TensorOrWeights>* outputs) {
- // FIXME assume type matches input weights
- // Get trt type & shape
- // Maybe this part has to be moved into the block of rsqrt later
+ bool swapped_inputs, std::vector<TRT_TensorOrWeights>* outputs) {
+ // tensor is the left operand while weights is the right operand;
+ // when swapped_inputs is set to true, the two are swapped.
+ // TODO(aaroey): use a set.
+ if (node_def.op() != "Sub" && node_def.op() != "Add" &&
+ node_def.op() != "Mul" && node_def.op() != "Div" &&
+ node_def.op() != "RealDiv") {
+ return tensorflow::errors::Unimplemented(
+ "op not supported: " + node_def.op() + ", at: " + node_def.name());
+ }
// Check type consistency
nvinfer1::DataType ttype;
@@ -910,6 +953,12 @@ tensorflow::Status BinaryTensorOpWeight(
auto dims_w = weights.shape_;
auto dims_t = tensor->getDimensions();
+ // TODO(jie): addScale checks for input tensor dimension
+ if (dims_t.nbDims != 3) {
+ return tensorflow::errors::InvalidArgument(
+ "addScale requires tensor with rank 3, " + node_def.name());
+ }
+
// default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
@@ -980,6 +1029,7 @@ tensorflow::Status BinaryTensorOpWeight(
permutation[dims_t.nbDims] = 1;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
permutation);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
} else {
return tensorflow::errors::InvalidArgument(
"Transpose cannot be applied, " + node_def.name());
@@ -997,11 +1047,35 @@ tensorflow::Status BinaryTensorOpWeight(
// Maybe I should do a switch
if (node_def.op() == "Sub") {
- TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
- LambdaFactory unary_op;
- unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
- TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
- shift_weights = neg_weights;
+ if (swapped_inputs) {
+ shift_weights = weights;
+ nvinfer1::IUnaryLayer* layer =
+ ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kNEG);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ } else {
+ TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
+ shift_weights = neg_weights;
+ }
+ } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
+ if (swapped_inputs) {
+ scale_weights = weights;
+ nvinfer1::IUnaryLayer* layer =
+ ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kRECIP);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ } else {
+ TRT_ShapedWeights recip_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
+ scale_weights = recip_weights;
+ }
} else if (node_def.op() == "Mul") {
scale_weights = weights;
} else if (node_def.op() == "Add") {
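The algebra behind the swapped-input branches above, as a standalone check: for Sub, w - x is rewritten as (-x) + w (kNEG on the tensor, weights as the Scale shift), and for Div/RealDiv, w / x is rewritten as (1/x) * w (kRECIP on the tensor, weights as the scale).

#include <cassert>

int main() {
  const float x = 4.0f, w = 10.0f;
  assert(w - x == (-x) + w);        // swapped Sub
  assert(w / x == (1.0f / x) * w);  // swapped Div / RealDiv
  return 0;
}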
@@ -1014,11 +1088,13 @@ tensorflow::Status BinaryTensorOpWeight(
nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
*const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
scale_weights, power_weights);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// transpose back dimension
if (permutation_flag) {
output_tensor = ctx.TransposeTensor(output_tensor, permutation);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
// Pass the output
@@ -1042,20 +1118,31 @@ tensorflow::Status ConvertConv2DHelper(
if (data_format == "NHWC") {
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 1, 2});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
h_index = 1;
w_index = 2;
// TODO(jie): transpose it
}
// tensor after transpose (NCHW)
- auto tensor_dim = tensor->getDimensions();
+ const auto tensor_dim = tensor->getDimensions();
int num_groups = group;
- if (num_groups == 0) // depthwise convolution
- num_groups = tensor_dim.d[0];
+ if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+
+ VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims;
+ for (int i = 0; i < weights_rsck.shape_.nbDims; i++) {
+ VLOG(2) << weights_rsck.shape_.d[i];
+ }
+
+ if (weights_rsck.shape_.nbDims != 4) {
+ return tensorflow::errors::Internal(
+ "Conv2D expects kernel of dimension 4, at: " + node_def.name());
+ }
+
if (ctx.isFP16()) {
weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
@@ -1063,18 +1150,22 @@ tensorflow::Status ConvertConv2DHelper(
TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
TRT_ShapedWeights biases(weights.type_);
- int noutput = weights.shape_.d[0] * num_groups;
+ const int noutput = weights.shape_.d[0] * num_groups;
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
+ VLOG(2) << "RSCK: ";
+ for (int i = 0; i < 4; i++) {
+ VLOG(2) << " " << weights.shape_.d[i];
+ }
VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
// TODO(jie): stride. (NHWC/NCHW)
- auto tf_stride = attrs.get<std::vector<int>>("strides");
+ const auto tf_stride = attrs.get<std::vector<int>>("strides");
VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
<< tf_stride[3];
- nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+ const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
std::vector<std::pair<int, int>> padding;
// TODO(jie): padding.
@@ -1102,6 +1193,7 @@ tensorflow::Status ConvertConv2DHelper(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
auto dim_after = tensor->getDimensions();
@@ -1112,6 +1204,7 @@ tensorflow::Status ConvertConv2DHelper(
nvinfer1::IConvolutionLayer* layer =
ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
noutput, kernel_size, weights, biases);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
@@ -1126,6 +1219,7 @@ tensorflow::Status ConvertConv2DHelper(
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
} else {
VLOG(2) << "NCHW !!!!";
}
@@ -1147,35 +1241,91 @@ tensorflow::Status ConvertConv2DHelper(
node_def.name());
}
+// Helper function converts input into tensor with shape specified by dims.
+bool PrepareTensorForShape(Converter& ctx, const TRT_TensorOrWeights& input,
+ const nvinfer1::Dims& dims,
+ const nvinfer1::ITensor** tensor) {
+ if (input.is_tensor()) {
+ if (DimsEqual(input.shape(), dims)) {
+ *tensor = input.tensor();
+ } else {
+ nvinfer1::IShuffleLayer* layer = ctx.network()->addShuffle(
+ *const_cast<nvinfer1::ITensor*>(input.tensor()));
+ if (layer != nullptr) {
+ layer->setReshapeDimensions(dims);
+ *tensor = layer->getOutput(0);
+ } else {
+ return false;
+ }
+ }
+ } else {
+#if NV_TENSORRT_MAJOR > 3
+ nvinfer1::IConstantLayer* layer =
+ ctx.network()->addConstant(dims, input.weights());
+ if (layer != nullptr) {
+ *tensor = layer->getOutput(0);
+ } else {
+ return false;
+ }
+#else
+ return false;
+#endif
+ }
+ return true;
+}
+
tensorflow::Status BinaryTensorOpTensor(
Converter& ctx, const tensorflow::NodeDef& node_def,
- const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
std::vector<TRT_TensorOrWeights>* outputs) {
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
+ {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
+ {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
+ {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
};
- // FIXME assume type matches input weights
+ const nvinfer1::ITensor* tensor_l;
+ const nvinfer1::ITensor* tensor_r;
+
+ nvinfer1::Dims dim_l;
+ nvinfer1::Dims dim_r;
+
+ if (!TensorRTGetBroadcastShape(operand_l.shape(), operand_l.is_tensor(),
+ operand_r.shape(), operand_r.is_tensor(),
+ &dim_l, &dim_r)) {
+ return tensorflow::errors::InvalidArgument(
+ "Binary op broadcast scheme not supported by TensorRT op: " +
+ node_def.op() + ", at: " + node_def.name());
+ }
+
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, operand_l, dim_l, &tensor_l), node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, operand_r, dim_r, &tensor_r), node_def.name());
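A hedged sketch of right-aligned broadcast alignment, the usual convention behind calls like the one above; the real TensorRTGetBroadcastShape also accounts for the implicit batch dimension and the tensor-vs-weights asymmetry, which this omits.

#include <vector>

bool AlignForBroadcast(std::vector<int> l, std::vector<int> r,
                       std::vector<int>* out_l, std::vector<int>* out_r) {
  // Pad the shorter shape with leading 1s, then require per-dim
  // compatibility (equal, or one side is 1).
  while (l.size() < r.size()) l.insert(l.begin(), 1);
  while (r.size() < l.size()) r.insert(r.begin(), 1);
  for (size_t i = 0; i < l.size(); ++i) {
    if (l[i] != r[i] && l[i] != 1 && r[i] != 1) return false;
  }
  *out_l = l;
  *out_r = r;
  return true;  // e.g. {5, 3} vs {3} -> {5, 3} and {1, 3}
}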
+
// get trt type & shape
TFAttrs attrs(node_def);
// maybe this part has to be moved into the block of rsqrt later
nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
// check type consistency
- CHECK_EQ_TYPE(tensor_l->getType(), dtype);
- CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
- if (op_pair == ops.end())
+ if (op_pair == ops.end()) {
return tensorflow::errors::Unimplemented(
- "binary op: " + node_def.op() +
- " not supported at: " + node_def.name());
+ "binary op: ", node_def.op(), " not supported at: ", node_def.name());
+ }
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ // TODO(aaroey): will tensor_l/tensor_r get modified?
*const_cast<nvinfer1::ITensor*>(tensor_l),
*const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
@@ -1202,7 +1352,7 @@ tensorflow::Status ConvertPlugin(Converter& ctx,
// passing attributes
// TODO(jie): support more general attribute
TFAttrs attrs(node_def);
- auto attr_key_vector = attrs.GetAllAttrKey();
+ auto attr_key_vector = attrs.GetAllAttrKeys();
for (auto attr_key : attr_key_vector) {
// TODO(jie): support only list of float for toy example here.
auto data = attrs.get<std::vector<float>>(attr_key);
@@ -1223,29 +1373,6 @@ tensorflow::Status ConvertPlugin(Converter& ctx,
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertPlaceholder(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- VLOG(2) << "Placeholder should have been replace already";
- return tensorflow::errors::Unimplemented("cannot convert Placeholder op");
- // OK this make sense since we are supposed to replace it with input
- TFAttrs attrs(node_def);
- nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
- nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
-
- dims.nbDims--;
- for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
-
- nvinfer1::ITensor* output =
- ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
- if (!output) {
- return tensorflow::errors::InvalidArgument("Failed to create Input layer");
- }
- outputs->push_back(TRT_TensorOrWeights(output));
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertConv2D(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
@@ -1271,65 +1398,64 @@ tensorflow::Status ConvertPool(Converter& ctx,
int h_index = 2;
int w_index = 3;
- auto data_format = attrs.get<string>("data_format");
+ const auto data_format = attrs.get<string>("data_format");
if (data_format == "NHWC") {
h_index = 1;
w_index = 2;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 1, 2});
- } else {
- VLOG(2) << "NCHW !!!!";
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
}
+
nvinfer1::PoolingType type;
- // TODO(jie): support other pooling type
- if (node_def.op() == "MaxPool")
+ if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
- else if (node_def.op() == "AvgPool")
+ } else if (node_def.op() == "AvgPool") {
type = nvinfer1::PoolingType::kAVERAGE;
- else
- return tensorflow::errors::Unimplemented("Only supports Max pool");
+ } else {
+ return tensorflow::errors::Unimplemented("Unsupported pool type: ",
+ node_def.op());
+ }
- // TODO(jie): NCHW
- auto tf_stride = attrs.get<std::vector<int>>("strides");
- nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+ const auto tf_stride = attrs.get<std::vector<int>>("strides");
+ const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
- auto tf_kernel = attrs.get<std::vector<int>>("ksize");
- nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+ const auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+ const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding;
- // TODO(jie): padding.
- if (attrs.get<string>("padding") == "SAME") {
+ const string padding_type = attrs.get<string>("padding");
+ if (padding_type == "SAME") {
// This is NCHW tensor with no batch dimension.
// 1 -> h
// 2 -> w
padding = CreateSamePadding(
stride, ksize,
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
- } else if (attrs.get<string>("padding") == "VALID") {
- // No padding for valid padding here
- VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
+ } else if (padding_type == "VALID") {
padding = {{0, 0}, {0, 0}};
} else {
- return tensorflow::errors::Unimplemented(
- "Current MaxPool cannot support padding other than SAME");
+ return tensorflow::errors::Unimplemented("Unsupported padding type: ",
+ padding_type);
}
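CreateSamePadding's body is outside this hunk; as a hedged reference, TensorFlow-style SAME padding per axis picks out = ceil(in / stride) and distributes the remainder, with the extra element on the trailing side.

#include <algorithm>
#include <utility>

std::pair<int, int> SamePadSketch(int in, int ksize, int stride) {
  const int out = (in + stride - 1) / stride;              // ceil(in / stride)
  const int total = std::max((out - 1) * stride + ksize - in, 0);
  return {total / 2, total - total / 2};                   // {before, after}
}
// e.g. in = 7, ksize = 3, stride = 2 -> out = 4, total = 2 -> {1, 1}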
if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) {
- // TODO(jie): handle asymmetric padding
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
auto pad_layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
}
nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
*const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
@@ -1337,10 +1463,8 @@ tensorflow::Status ConvertPool(Converter& ctx,
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
if (data_format == "NHWC") {
- // TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
- } else {
- VLOG(2) << "NCHW !!!!";
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1353,6 +1477,7 @@ tensorflow::Status ConvertActivation(
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
*const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1363,40 +1488,61 @@ tensorflow::Status ConvertScale(Converter& ctx,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::Unimplemented(
- "Only supports tensor op weight for now, at " + node_def.name());
- // Implement tensor binaryOp weight [channel wise] for now;
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ "ConvertScale only supports tensor<op>weight: ", node_def.name());
+ }
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TRT_ShapedWeights weights = inputs.at(1).weights();
if (ctx.isFP16()) {
weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
TRT_ShapedWeights empty_weights(weights.type_);
-
TFAttrs attrs(node_def);
- // Transpose NHWC
- auto data_format = attrs.get<string>("data_format");
+ const auto data_format = attrs.get<string>("data_format");
+ int channel_index;
+ const auto dims = tensor->getDimensions();
if (data_format == "NHWC") {
- tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
- {0, 3, 1, 2});
- // TODO(jie): transpose it
+ // 1) NHWC is really N + ...C: the channel is the last dimension.
+ channel_index = dims.nbDims - 1; // batch dimension is implicit here!
} else {
- VLOG(2) << "NCHW !!!!";
+ // 2) NCHW is really N + CHW: the channel is third from the end.
+ channel_index = dims.nbDims - 3; // batch dimension is implicit here!
}
- auto dims = tensor->getDimensions();
- VLOG(2) << "tensor dimensions: " << dims.nbDims;
- for (int i = 0; i < dims.nbDims; i++) {
- VLOG(2) << "i: " << dims.d[i];
+ nvinfer1::Permutation permutation;
+ for (int32_t i = 0; i < dims.nbDims; ++i) {
+ permutation.order[i] = i;
}
- dims = weights.shape_;
- VLOG(2) << "tensor dimensions: " << dims.nbDims;
- for (int i = 0; i < dims.nbDims; i++) {
- VLOG(2) << "i: " << dims.d[i];
+
+ if (channel_index >= 0) {
+ permutation.order[0] = channel_index;
+ permutation.order[channel_index] = 0;
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name());
+ }
+
+ // TensorRT addScale requires the input to be of rank 3, so we need to
+ // apply a transpose as well as a reshape.
+ if (channel_index != 0 || dims.nbDims != 3) {
+ nvinfer1::IShuffleLayer* shuffle_layer =
+ ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
+ nvinfer1::Dims reshape_dims;
+ reshape_dims.nbDims = 3;
+ reshape_dims.d[0] = 0; // 0 copy from the input
+ reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input
+ reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest
+ if (channel_index != 0) {
+ // Maybe we do not need this check; it is kept out of concern for TRT
+ // optimization.
+ shuffle_layer->setFirstTranspose(permutation);
+ }
+ shuffle_layer->setReshapeDimensions(reshape_dims);
+ tensor = shuffle_layer->getOutput(0);
}
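Worked examples of the reshape_dims convention used above, per TensorRT's IShuffleLayer semantics (0 copies that dimension from the input, -1 is inferred from the remaining volume):

// dims = {64}          -> reshape {0, 1, 1}  -> {64, 1, 1}
// dims = {64, 7}       -> reshape {0, 0, 1}  -> {64, 7, 1}
// dims = {64, 7, 7}    -> reshape {0, 0, -1} -> {64, 7, 7}
// dims = {64, 7, 7, 5} -> reshape {0, 0, -1} -> {64, 7, 35} (trailing dims fold)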
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
@@ -1407,14 +1553,26 @@ tensorflow::Status ConvertScale(Converter& ctx,
nvinfer1::IScaleLayer* layer =
ctx.network()->addScale(*const_cast<nvinfer1::ITensor*>(tensor), mode,
weights, empty_weights, empty_weights);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- if (data_format == "NHWC") {
- // TODO(jie): transpose it back!
- output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
- } else {
- VLOG(2) << "NCHW !!!!";
+
+ // restore transpose & reshape
+ if (channel_index != 0 || dims.nbDims != 3) {
+ nvinfer1::IShuffleLayer* shuffle_layer = ctx.network()->addShuffle(
+ *const_cast<nvinfer1::ITensor*>(output_tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
+ nvinfer1::Dims reshape_dims = dims;
+ int tmp = reshape_dims.d[channel_index];
+ reshape_dims.d[channel_index] = reshape_dims.d[0];
+ reshape_dims.d[0] = tmp;
+ shuffle_layer->setReshapeDimensions(reshape_dims);
+ if (channel_index != 0) {
+ shuffle_layer->setSecondTranspose(permutation);
+ }
+ output_tensor = shuffle_layer->getOutput(0);
}
+
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1431,11 +1589,13 @@ tensorflow::Status ConvertConst(Converter& ctx,
// Create shaped weights as output
tensorflow::Tensor tensor;
- if (!tensor.FromProto(weights_tensor))
- return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
+ if (!tensor.FromProto(weights_tensor)) {
+ return tensorflow::errors::Internal("Cannot parse weight tensor proto: ",
node_def.name());
+ }
TRT_ShapedWeights weights(dtype);
+ // TODO(aaroey): we should choose the array using dtype and shape.
if (!weights_tensor.float_val().empty()) {
VLOG(2) << "SCALAR!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
@@ -1443,22 +1603,16 @@ tensorflow::Status ConvertConst(Converter& ctx,
VLOG(2) << "dimensions: " << tensor.dims();
VLOG(2) << "size: " << weights_tensor.float_val_size();
scalar_shape = GetTensorShape(tensor);
+ VLOG(2) << "details: ";
for (int i = 0; i < scalar_shape.nbDims; i++)
VLOG(2) << scalar_shape.d[i];
- if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) {
- if (weights_tensor.float_val_size() == 1 ||
- scalar_shape.d[0] == weights_tensor.float_val_size()) {
- scalar_shape.nbDims = 1;
- // no dimension provided. flatten it
- scalar_shape.d[0] = weights_tensor.float_val_size();
- scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
- } else {
- LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
- string err_str("Broadcast method is not supported for '");
- StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
- return tensorflow::errors::InvalidArgument(err_str);
- }
+ if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size() &&
+ weights_tensor.float_val_size() != 1) {
+ LOG(ERROR) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
} else {
VLOG(2) << "Dimensions: " << tensor.dims();
@@ -1468,39 +1622,42 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
scalar_shape.d[i] = 0;
- scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
+ // TODO(aaroey): use GetShapeSize().
size_t len_data = tensorflow::DataTypeSize(dtype);
for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<float> tensor_data(
- weights_tensor.float_val().begin(),
- weights_tensor.float_val()
- .end()); // make a local copy first to flatten
- memcpy(dst, tensor_data.data(), len_data); // store into weight store
+ if (weights_tensor.float_val_size() == 1) {
+ std::fill_n((float*)dst, GetShapeSize(scalar_shape),
+ *weights_tensor.float_val().begin());
+ } else {
+      // TODO(aaroey): get rid of this copy, since RepeatedField is always
+      // contiguous; the local copy was only made to flatten data that did
+      // not have to be contiguous.
+ std::vector<float> tensor_data(weights_tensor.float_val().begin(),
+ weights_tensor.float_val().end());
+ memcpy(dst, tensor_data.data(), len_data); // store into weight store
+ }
+ VLOG(2) << "create shape details: ";
+ for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i];
weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.int_val().empty()) {
+ // TODO(aaroey): this is very similar to the above code for float, merge
+ // them.
VLOG(2) << "int!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
if (tensor.dims() > 0) {
VLOG(2) << "dimensions: " << tensor.dims();
scalar_shape = GetTensorShape(tensor);
- if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) {
- if (weights_tensor.int_val_size() == 1 ||
- scalar_shape.d[0] == weights_tensor.int_val_size()) {
- scalar_shape.nbDims = 1;
- // no dimension provided. flatten it
- scalar_shape.d[0] = weights_tensor.int_val_size();
- scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
- } else {
- LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
- string err_str("Broadcast method is not supported for '");
- StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
- return tensorflow::errors::InvalidArgument(err_str);
- }
+ if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size() &&
+ weights_tensor.int_val_size() != 1) {
+ LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
} else {
VLOG(2) << "dimensions: " << tensor.dims();
@@ -1513,23 +1670,30 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
- // we should not have converted //if (ctx.isFP16()) {
+    // We should not have converted these weights (e.g. to FP16).
size_t len_data = tensorflow::DataTypeSize(dtype);
for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32);
len_data = std::max(len_data, len_tensor);
ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<int32> tensor_data(
- weights_tensor.int_val().begin(),
- weights_tensor.int_val().end()); // make a local copy first to flatten
- // doesn't have to be contigous
- memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
+ if (weights_tensor.int_val_size() == 1) {
+ std::fill_n((int*)dst, GetShapeSize(scalar_shape),
+ *weights_tensor.int_val().begin());
+ } else {
+      // TODO(aaroey): get rid of this copy, since RepeatedField is always
+      // contiguous; the local copy was only made to flatten data that did
+      // not have to be contiguous.
+ std::vector<int32> tensor_data(weights_tensor.int_val().begin(),
+ weights_tensor.int_val().end());
+ memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
+ }
weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.tensor_content().empty()) {
- // obsolete method.
- // After optimization path, we do not see weights in this format.
- // fp16 conversion technically should be needed here.
+    // Obsolete method: after the optimization pass, we do not see weights in
+    // this format.
+    // TODO(aaroey): why?
+    // FP16 conversion would technically be needed here.
VLOG(2) << "TENSOR!!!" << node_def.name();
const auto& content = weights_tensor.tensor_content();
@@ -1543,8 +1707,8 @@ tensorflow::Status ConvertConst(Converter& ctx,
content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
}
} else {
- return tensorflow::errors::Unimplemented(
- "Not supported constant type, at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Not supported constant type, at ",
+ node_def.name());
}
// Pass the output
outputs->push_back(TRT_TensorOrWeights(weights));
@@ -1563,96 +1727,144 @@ tensorflow::Status ConvertBinary(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 2)
+ if (inputs.size() != 2) {
return tensorflow::errors::FailedPrecondition(
- "Binary ops require two tensor input, at " + node_def.name());
-
- if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
- return ConstantFoldBinary(ctx, node_def, inputs, outputs);
-
- if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
- return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
- inputs.at(1).weights(), outputs);
+ "Binary ops require two tensor input, at ", node_def.name());
+ }
- if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
- return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
- inputs.at(0).weights(), outputs);
+ // Constant folding should have been done by TensorFlow
- if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
- return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
- inputs.at(1).tensor(), outputs);
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
+ return tensorflow::errors::Unimplemented(
+ "Constant folding is falled back to TensorFlow, binary op received "
+ "both input as constant at: ",
+ node_def.name());
+ }
- return tensorflow::errors::Unknown("Binary op input error, at " +
- node_def.name());
+ // Try to convert into Scale layer first (for better performance)
+ // Since scale layer supports restricted broadcast policy and op types, we
+ // allow failure and try to handle it through Elementwise op
+ // (BinaryTensorOpTensor)
+ Status status = tensorflow::Status::OK();
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
+ status = BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).weights(), false, outputs);
+ } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
+ status = BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+ inputs.at(0).weights(), true, outputs);
+#if NV_TENSORRT_MAJOR == 3
+ } else {
+#else
+ }
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() || !status.ok()) {
+#endif
+ status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1),
+ outputs);
+ }
+ return status;
}
tensorflow::Status ConvertUnary(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 1)
+ static const std::unordered_map<string, nvinfer1::UnaryOperation> ops{
+ {"Neg", nvinfer1::UnaryOperation::kNEG},
+ {"Exp", nvinfer1::UnaryOperation::kEXP},
+ {"Log", nvinfer1::UnaryOperation::kLOG},
+ {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
+ {"Abs", nvinfer1::UnaryOperation::kABS},
+ {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
+ };
+
+ if (inputs.size() != 1) {
return tensorflow::errors::FailedPrecondition(
- "Unary ops require single tensor input, at " + node_def.name());
+ "Unary ops require single tensor input, at ", node_def.name());
+ }
- if (inputs.at(0).is_weights())
- return ConstantFoldUnary(ctx, node_def, inputs, outputs);
- else if (inputs.at(0).is_tensor())
+#if NV_TENSORRT_MAJOR == 3
+ if (inputs.at(0).is_weights()) {
return tensorflow::errors::Unimplemented(
- "Unary op for tensor not supported, at " + node_def.name());
+ "Constant folding for unary op is not supported", node_def.name());
+ }
+#endif
+
+ // TODO(jie): check type
+ const nvinfer1::ITensor* tensor;
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(0), inputs.at(0).shape(), &tensor),
+ node_def.name());
+
+ nvinfer1::IUnaryLayer* layer;
+ if (node_def.op() == "Rsqrt") {
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kSQRT);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kRECIP);
+ } else if (ops.count(node_def.op()) != 0) {
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ ops.at(node_def.op()));
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Binary op: ", node_def.op(), " not supported, at ", node_def.name());
+ }
- return tensorflow::errors::Unknown("Binary op input error, at " +
- node_def.name());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
}
-tensorflow::Status ConvertReduce(Converter& ctx,
- const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
+#if NV_TENSORRT_MAJOR == 3
+tensorflow::Status ConvertReducePool(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ "Input expects tensor and weights, at", node_def.name());
+ }
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
+ const auto dims = tensor->getDimensions();
// Restore implicit batch dimension
- int nb_dims = dims.nbDims + 1;
+ const int nb_dims = dims.nbDims + 1;
TRT_ShapedWeights index_list = inputs.at(1).weights();
-
TFAttrs attrs(node_def);
- // TODO(jie): handle data type.
- // Index type here is done through TF type, so I can leverage their
- // EnumToDataType for my cast
auto index_type = attrs.get<tensorflow::DataType>("Tidx");
// Only expect to handle INT32 as attributes for now
- if (index_type != tensorflow::DataType::DT_INT32)
+ if (index_type != tensorflow::DataType::DT_INT32) {
return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
- auto index_list_data =
+ }
+ const auto index_list_data =
static_cast<int*>(const_cast<void*>(index_list.GetValues()));
- // Hack warning: have to fall back to pool layer since reduce is not in public
- // TRT yet.
- if (nb_dims != 4)
+ if (nb_dims != 4) {
return tensorflow::errors::InvalidArgument(
- "TRT only support reduce on 4 dimensional tensors, at" +
+ "TRT only support reduce on 4 dimensional tensors, at",
node_def.name());
- if (index_list.count() > 2)
+ }
+ if (index_list.count() > 2) {
return tensorflow::errors::InvalidArgument(
- "TRT cannot support reduce on more than 2 dimensions, at" +
+ "TRT cannot support reduce on more than 2 dimensions, at",
node_def.name());
+ }
std::set<int> idx_set;
// We cannot operate on Channel. permutation flag used to transpose tensor
int permuted_index = -1;
for (int i = 0; i < index_list.count(); i++) {
- if (index_list_data[i] == 0)
- return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
+ if (index_list_data[i] == 0) {
+ return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at",
node_def.name());
+ }
if (index_list_data[i] == 1) permuted_index = 1;
-
idx_set.emplace(index_list_data[i]);
}
@@ -1673,6 +1885,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
// Apply permutation before extracting dimension for pool_kernel
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
}
// Apply permutation before extracting dimension for pool_kernel
@@ -1685,34 +1898,104 @@ tensorflow::Status ConvertReduce(Converter& ctx,
nvinfer1::IPoolingLayer* layer =
ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
output_tensor = layer->getOutput(0);
} else {
- return tensorflow::errors::Unimplemented(
- "Op not supported " + node_def.op() + " , at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(),
+ " , at ", node_def.name());
}
if (permuted_index != -1) {
// Apply permutation before extracting dimension for pool_kernel
output_tensor = ctx.TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
+#elif NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertReduce(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights()) {
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at", node_def.name());
+ }
+
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+
+ // Only expect to handle INT32 as attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32) {
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+ }
+
+ const auto keep_dims = attrs.get<bool>("keep_dims");
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
+
+ int axes = 0;
+ if (index_list.count() == 0) {
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot support reduce on all (batch) dimensions, at",
+ node_def.name());
+ } else {
+ for (int i = 0; i < index_list.count(); i++) {
+ if (index_list_data[i] == 0) {
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot reduce at batch dimension, at", node_def.name());
+ }
+ axes |= (1 << (index_list_data[i] - 1));
+ }
+ }
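A standalone check of the axis translation above: TF axis i (with axis 0 being the implicit batch dimension) becomes bit i - 1 of the mask passed to addReduce.

#include <cassert>

int main() {
  int axes = 0;
  const int tf_axes[] = {1, 2};  // e.g. reduce over H and W
  for (int a : tf_axes) axes |= (1 << (a - 1));
  assert(axes == 3);  // TRT dims 0 and 1, i.e. binary 011
  return 0;
}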
+
+ nvinfer1::ReduceOperation reduce_operation;
+ if (node_def.op() == "Sum") {
+ reduce_operation = nvinfer1::ReduceOperation::kSUM;
+ } else if (node_def.op() == "Prod") {
+ reduce_operation = nvinfer1::ReduceOperation::kPROD;
+ } else if (node_def.op() == "Max") {
+ reduce_operation = nvinfer1::ReduceOperation::kMAX;
+ } else if (node_def.op() == "Min") {
+ reduce_operation = nvinfer1::ReduceOperation::kMIN;
+ } else if (node_def.op() == "Mean") {
+ reduce_operation = nvinfer1::ReduceOperation::kAVG;
+ } else {
+ return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(),
+ " , at ", node_def.name());
+ }
+
+ nvinfer1::ILayer* layer =
+ ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor),
+ reduce_operation, axes, keep_dims);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+ outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+ return tensorflow::Status::OK();
+}
+#endif
tensorflow::Status ConvertPad(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
+ // TODO(aaroey): make a routine for this check and reuse it.
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ "Input expects tensor and weights, at", node_def.name());
+ }
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
+ const auto dims = tensor->getDimensions();
// Restore implicit batch dimension
- int nb_dims = dims.nbDims + 1;
+ const int nb_dims = dims.nbDims + 1;
TRT_ShapedWeights pads = inputs.at(1).weights();
@@ -1722,21 +2005,24 @@ tensorflow::Status ConvertPad(Converter& ctx,
auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
// TODO(jie): handle data type conversion for TRT?
- if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
+ if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) {
return tensorflow::errors::InvalidArgument(
- "Pad only supports explicit padding on 4 dimensional tensor, at " +
+ "Pad only supports explicit padding on 4 dimensional tensor, at ",
node_def.name());
+ }
// Only expect to handle INT32 as attributes for now
- if (padding_type != tensorflow::DataType::DT_INT32)
+ if (padding_type != tensorflow::DataType::DT_INT32) {
return tensorflow::errors::Unimplemented(
"Tpaddings supports only DT_INT32");
+ }
auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
std::vector<int32_t> pad_index;
for (int i = 0; i < nb_dims; i++) {
- if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+ if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) {
pad_index.push_back(i);
+ }
}
// No padding at all, we should exit
@@ -1746,20 +2032,23 @@ tensorflow::Status ConvertPad(Converter& ctx,
}
// Only supports padding on less than 2 axis GIE-2579
- if (pad_index.size() > 2)
+ if (pad_index.size() > 2) {
return tensorflow::errors::InvalidArgument(
"Padding layer does not support padding on > 2");
+ }
// Padding on batch dimension is not supported
- if (pad_index[0] == 0)
+ if (pad_index[0] == 0) {
return tensorflow::errors::InvalidArgument(
"Padding layer does not support padding on batch dimension");
+ }
  // Not doing the legit thing here: ignoring padding on dims 1 and 3.
// TODO(jie): implement pad as uff parser
- if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3)
+ if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) {
return tensorflow::errors::Unimplemented(
"Padding layer does not support padding on dimension 1 and 3 yet");
+ }
bool legit_pad = true;
nvinfer1::DimsHW pre_padding(0, 0);
@@ -1770,6 +2059,7 @@ tensorflow::Status ConvertPad(Converter& ctx,
legit_pad = false;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 2, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
permuted_pad_index[0] = 3;
}
@@ -1786,11 +2076,14 @@ tensorflow::Status ConvertPad(Converter& ctx,
nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- if (!legit_pad)
+ if (!legit_pad) {
output_tensor = ctx.TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
+ }
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1803,9 +2096,10 @@ tensorflow::Status ConvertConcat(Converter& ctx,
// not including the last input (axis) here
int input_size = static_cast<int>(inputs.size()) - 1;
- if (!inputs.at(0).is_tensor())
+ if (!inputs.at(0).is_tensor()) {
return tensorflow::errors::InvalidArgument(
- "Concat in TRT support only Tensor input, at " + node_def.name());
+ "Concat in TRT support only Tensor input, at ", node_def.name());
+ }
// We are retrieving the axis
TRT_ShapedWeights axis = inputs.at(input_size).weights();
@@ -1816,8 +2110,8 @@ tensorflow::Status ConvertConcat(Converter& ctx,
// TODO(jie): handle data type
// Only expect to handle INT32 as index attributes for now
if (index_type != tensorflow::DataType::DT_INT32)
- return tensorflow::errors::Unimplemented(
- "Tidx supports only DT_INT32, at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ",
+ node_def.name());
int index = *(static_cast<int*>(const_cast<void*>(axis.GetValues())));
@@ -1825,23 +2119,29 @@ tensorflow::Status ConvertConcat(Converter& ctx,
auto dim = inputs.at(0).tensor()->getDimensions();
// dimension check
- if (index > dim.nbDims + 1)
+ if (index > dim.nbDims + 1) {
return tensorflow::errors::InvalidArgument(
- "Concatenate on axis out of dimension range, at " + node_def.name());
-
- if (index == 0)
+ "Concatenate on axis out of dimension range, at ", node_def.name());
+ }
+ if (index == 0) {
return tensorflow::errors::InvalidArgument(
- "Concatenate on batch dimension not supported, at " + node_def.name());
+ "Concatenate on batch dimension not supported, at ", node_def.name());
+ }
+ if (index < 0) {
+ index = dim.nbDims + index + 1;
+ }
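A worked example of the negative-axis fix-up above: the TF tensor has rank dim.nbDims + 1 because TRT's batch dimension is implicit, so with dim.nbDims == 3 a TF axis of -1 maps to 3 + (-1) + 1 == 3, the last TF axis.

// dim.nbDims == 3 (TF rank 4): index -1 -> 3, index -2 -> 2, index -3 -> 1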
+#if NV_TENSORRT_MAJOR == 3
  // in case we need a permutation
std::vector<int> permutation_order(dim.nbDims + 1);
for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i;
if (index != 1) {
- permutation_order[1] = index - 1;
- permutation_order[index - 1] = 1;
+ permutation_order[1] = index;
+ permutation_order[index] = 1;
}
+#endif
std::vector<nvinfer1::ITensor const*> inputs_vec;
  // Shape check (all input tensors should have the same shape)
@@ -1849,24 +2149,28 @@ tensorflow::Status ConvertConcat(Converter& ctx,
for (int i = 0; i < input_size; i++) {
auto tensor_i = inputs.at(i).tensor();
auto dim_i = tensor_i->getDimensions();
- if (dim_i.nbDims != dim.nbDims)
+ if (dim_i.nbDims != dim.nbDims) {
return tensorflow::errors::InvalidArgument(
- "Concatenate receives inputs with inconsistent dimensions, at " +
+ "Concatenate receives inputs with inconsistent dimensions, at ",
node_def.name());
-
+ }
for (int j = 0; j < dim.nbDims; j++) {
// check dimension consistency on non-concatenate axis
- if (j != index - 1 && dim_i.d[j] != dim.d[j])
+ if (j != index - 1 && dim_i.d[j] != dim.d[j]) {
return tensorflow::errors::InvalidArgument(
- "Concatenate receives inputs with inconsistent shape, at" +
+ "Concatenate receives inputs with inconsistent shape, at",
node_def.name());
+ }
}
- // TRT does concatenation only on channel!
- if (index != 1)
+#if NV_TENSORRT_MAJOR == 3
+ // TRT3 does concatenation only on channel!
+ if (index != 1) {
tensor_i = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor_i),
permutation_order);
-
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor_i, node_def.name());
+ }
+#endif
inputs_vec.push_back(tensor_i);
}
@@ -1874,11 +2178,18 @@ tensorflow::Status ConvertConcat(Converter& ctx,
nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(inputs_vec.data()),
inputs_vec.size());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+#if NV_TENSORRT_MAJOR > 3
+ layer->setAxis(index - 1);
+#endif
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+#if NV_TENSORRT_MAJOR == 3
if (index != 1) {
output_tensor = ctx.TransposeTensor(output_tensor, permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
+#endif
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1997,112 +2308,249 @@ tensorflow::Status ConvertFusedBatchNorm(
combined_offset_weights.GetWeightsForTRT(),
combined_scale_weights.GetWeightsForTRT(),
dummy_power_weights.GetWeightsForTRT());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
+#if NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertMatMulHelper(
+ Converter& ctx, TRT_TensorOrWeights tensor_input,
+ TRT_ShapedWeights weights_raw, bool transpose_weight, string node_name,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor* output_tensor;
+ if (!tensor_input.is_tensor()) {
+ return tensorflow::errors::InvalidArgument("Input 0 expects tensor");
+ }
+ const nvinfer1::ITensor* tensor = tensor_input.tensor();
+
+ TRT_ShapedWeights weights(weights_raw.type_);
+ if (transpose_weight) {
+ weights = weights_raw;
+ } else {
+ TRT_ShapedWeights weights_ck = weights_raw;
+ weights = ctx.get_temp_weights_like(weights_ck);
+ ReorderCKtoKC(weights_raw, &weights);
+ }
+ TRT_ShapedWeights biases(weights.type_);
+
+ int noutput = weights.shape_.d[0];
+
+ auto input_dim = tensor->getDimensions();
+ while (input_dim.nbDims != 3) {
+ input_dim.d[input_dim.nbDims++] = 1;
+ }
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, tensor_input, input_dim, &tensor), node_name);
+
+ nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected(
+ *const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name);
+ output_tensor = layer->getOutput(0);
+
+ const nvinfer1::ITensor* temp_tensor;
+ auto output_dim = output_tensor->getDimensions();
+ output_dim.nbDims = 1;
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, TRT_TensorOrWeights(output_tensor), output_dim,
+ &temp_tensor),
+ node_name);
+ output_tensor = const_cast<nvinfer1::ITensor*>(temp_tensor);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
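The while-loop above pads the TRT dims with trailing 1s until rank 3, since addFullyConnected expects a CHW-like input; a standalone mirror of that fix-up:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> dims = {256};  // e.g. a flattened activation
  while (dims.size() != 3) dims.push_back(1);
  assert((dims == std::vector<int>{256, 1, 1}));
  return 0;
}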
+
+// inputs are both two dimensional (tensorflow::ops::MatMul)
tensorflow::Status ConvertMatMul(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
+ if (!inputs.at(0).is_tensor()) {
+ return tensorflow::errors::InvalidArgument("Input 0 expects tensor, at" +
+ node_def.name());
+ }
+
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- // TODO(jie): transpose!
TFAttrs attrs(node_def);
- TRT_ShapedWeights weights_ck = inputs.at(1).weights();
- TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck);
- ReorderCKtoKC(weights_ck, &weights);
- TRT_ShapedWeights biases(weights.type_);
+ // TODO(jie): INT32 should be converted?
+ tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
+ if (tf_dtype != tensorflow::DataType::DT_FLOAT &&
+ tf_dtype != tensorflow::DataType::DT_HALF) {
+ return tensorflow::errors::Unimplemented(
+ "data type is not supported, for node " + node_def.name() + " got " +
+ tensorflow::DataTypeString(tf_dtype));
+ }
- int noutput = weights.shape_.d[0];
+ bool transpose_a = attrs.get<bool>("transpose_a");
+ bool transpose_b = attrs.get<bool>("transpose_b");
- nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected(
- *const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases);
+ nvinfer1::ITensor* output_tensor;
- nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- outputs->push_back(TRT_TensorOrWeights(output_tensor));
- return tensorflow::Status::OK();
+ // FullyConnected:
+ if (transpose_a) {
+ return tensorflow::errors::Internal(
+ "Transpose_a is not supported for TensorRT FullyConnected (op: " +
+ node_def.op() + "), at: " + node_def.name());
+ }
+ if (inputs.at(1).is_tensor()) {
+ return tensorflow::errors::Internal(
+ "Operand 1 must be constant for TensorRT FullyConnected (op: " +
+ node_def.op() + "), at: " + node_def.name());
+ }
+ return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(),
+ transpose_b, node_def.name(), outputs);
}
-tensorflow::Status ConvertReshape(
+tensorflow::Status ConvertBatchMatMul(
Converter& ctx, const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
- return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ TFAttrs attrs(node_def);
- // implement tensor binaryOp weight [channel wise] for now;
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
- // restore implicit batch dimension
+ // TODO(jie): INT32 should be converted?
+ tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
+ if (tf_dtype != tensorflow::DataType::DT_FLOAT &&
+ tf_dtype != tensorflow::DataType::DT_HALF) {
+ return tensorflow::errors::Unimplemented(
+ "data type is not supported, for node " + node_def.name() + " got " +
+ tensorflow::DataTypeString(tf_dtype));
+ }
- TRT_ShapedWeights shape = inputs.at(1).weights();
+ bool transpose_a = attrs.get<bool>("adj_x");
+ bool transpose_b = attrs.get<bool>("adj_y");
- TFAttrs attrs(node_def);
+ auto dims = inputs.at(0).shape();
+ if (dims.nbDims == 1) { // NC * CK is only supported through fully connected
+ if (transpose_a == false && inputs.at(0).is_tensor() &&
+ inputs.at(1).is_weights()) {
+ return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(),
+ transpose_b, node_def.name(), outputs);
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Invalid configuration for MatMul, at: " + node_def.name());
+ }
+ }
- auto padding_type = attrs.get<tensorflow::DataType>("Tshape");
+ const nvinfer1::ITensor* tensor_l;
+ const nvinfer1::ITensor* tensor_r;
+ auto dims_l = inputs.at(0).shape();
+ auto dims_r = inputs.at(1).shape();
+ if (inputs.at(0).is_weights()) {
+ if (inputs.at(0).shape().d[0] != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Input 0 as weight assumes broadcast across batch for MatMul, at: " +
+ node_def.name());
+ } else {
+ for (int i = 0; i < dims_l.nbDims - 1; i++) {
+ dims_l.d[i] = dims_l.d[i + 1];
+ }
+ dims_l.nbDims--;
+ }
+ }
+ if (inputs.at(1).is_weights()) {
+ if (inputs.at(1).shape().d[0] != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Input 1 as weight assumes broadcast across batch for MatMul, at: " +
+ node_def.name());
+ } else {
+ for (int i = 0; i < dims_r.nbDims - 1; i++) {
+ dims_r.d[i] = dims_r.d[i + 1];
+ }
+ dims_r.nbDims--;
+ }
+ }
- if (shape.shape_.nbDims != 1)
- return tensorflow::errors::InvalidArgument(
- "reshape new shape is not 1 dimensional, at " + node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(0), dims_l, &tensor_l),
+ node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(1), dims_r, &tensor_r),
+ node_def.name());
- // Only expect to handle INT32 as attributes for now
- if (padding_type != tensorflow::DataType::DT_INT32)
- return tensorflow::errors::Unimplemented(
- "reshape new shape supports only DT_INT32, at " + node_def.name());
+ nvinfer1::IMatrixMultiplyLayer* layer = ctx.network()->addMatrixMultiply(
+ *const_cast<nvinfer1::ITensor*>(tensor_l), transpose_a,
+ *const_cast<nvinfer1::ITensor*>(tensor_r), transpose_b);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+#endif
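
The dimension-shifting loops above drop a constant operand's leading batch dimension before the multiply; roughly, in Python (a sketch, not the converter's API):

    def drop_broadcast_batch_dim(shape):
        # A constant operand may only carry a leading batch dim of 1, which
        # is broadcast across the batch and dropped here.
        if shape[0] != 1:
            raise ValueError('weight operand must broadcast across the batch')
        return shape[1:]

    print(drop_broadcast_batch_dim((1, 4, 5)))  # -> (4, 5)
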
- auto shape_data = static_cast<int*>(const_cast<void*>(shape.GetValues()));
+#if NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertSoftmax(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- if (shape_data[0] != -1)
+ int nbDims = tensor->getDimensions().nbDims;
+ if (nbDims == 0) {
     return tensorflow::errors::InvalidArgument(
-        "reshape new shape first dimension is not -1, at " + node_def.name());
+        "TensorRT Softmax cannot be applied to the batch dimension, at " +
+        node_def.name());
+ }
+ nvinfer1::ISoftMaxLayer* layer =
+ ctx.network()->addSoftMax(*const_cast<nvinfer1::ITensor*>(tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+  // TensorFlow's Softmax op applies softmax along the last dimension.
+ layer->setAxes(1 << (nbDims - 1));
- auto shape_num_dims = shape.shape_.d[0];
- VLOG(2) << "shape dimensions: " << shape_num_dims;
- int volume_w = 1;
- for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i];
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+#endif
- int volume_t = 1;
- for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i];
+#if NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertTopK(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w;
- if (volume_w != volume_t)
+ int nbDims = tensor->getDimensions().nbDims;
+ if (nbDims == 0) {
     return tensorflow::errors::InvalidArgument(
-        "volume does not agree between tensor and new shape, at " +
-        node_def.name());
+        "TensorRT TopK cannot be applied to the batch dimension, at " +
+        node_def.name());
+ }
- nvinfer1::IShuffleLayer* layer =
- ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor));
+ TRT_ShapedWeights k_w = inputs.at(1).weights();
+ int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues())));
- nvinfer1::Dims reshape_dims;
- VLOG(2) << "new dimension: " << shape_num_dims - 1;
- reshape_dims.nbDims = shape_num_dims - 1;
- for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
- reshape_dims.d[i] = shape_data[i + 1];
+ nvinfer1::TopKOperation op;
+ uint32_t reducedAxes = 0;
+ if (node_def.op() == "TopKV2") {
+ op = nvinfer1::TopKOperation::kMAX;
+ reducedAxes |= 1 << (nbDims - 1);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Operation: " + node_def.op() +
+ " not implemented, at: " + node_def.name());
}
- layer->setReshapeDimensions(reshape_dims);
- VLOG(2) << "new dimension: " << shape_num_dims - 1;
- nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- auto dims_output = output_tensor->getDimensions();
- VLOG(2) << "output tensor dimension:" << dims_output.nbDims;
- outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ nvinfer1::ITopKLayer* layer = ctx.network()->addTopK(
+ *const_cast<nvinfer1::ITensor*>(tensor), op, k, reducedAxes);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+ nvinfer1::ITensor* output_value_tensor = layer->getOutput(0);
+ nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
+ outputs->push_back(TRT_TensorOrWeights(output_value_tensor));
+ outputs->push_back(TRT_TensorOrWeights(output_indices_tensor));
return tensorflow::Status::OK();
}
+#endif
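
Both ConvertSoftmax and ConvertTopK above encode the target axis as a bitmask: the setAxes()/reduceAxes arguments are uint32 values in which bit i selects dimension i of the (batch-free) tensor. A one-line illustration:

    def last_axis_bitmask(nb_dims):
        assert nb_dims > 0, 'cannot operate on the batch dimension'
        return 1 << (nb_dims - 1)

    assert last_axis_bitmask(3) == 0b100  # selects the last of three dims
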
void Converter::register_op_converters() {
// vgg_16 slim implementation
- op_registry_["Placeholder"] = ConvertPlaceholder;
op_registry_["Conv2D"] = ConvertConv2D;
op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
op_registry_["Relu"] = ConvertActivation;
op_registry_["MaxPool"] = ConvertPool;
op_registry_["AvgPool"] = ConvertPool;
- // This could be really handled as ConvertBinary
op_registry_["BiasAdd"] = ConvertScale;
op_registry_["Const"] = ConvertConst;
// TODO(ben,jie): this is a temp hack.
@@ -2113,18 +2561,38 @@ void Converter::register_op_converters() {
op_registry_["Add"] = ConvertBinary;
op_registry_["Mul"] = ConvertBinary;
op_registry_["Sub"] = ConvertBinary;
- op_registry_["Rsqrt"] = ConvertUnary;
- op_registry_["Mean"] = ConvertReduce;
op_registry_["Pad"] = ConvertPad;
- // TODO(ben,jie): Add more ops
op_registry_["ConcatV2"] = ConvertConcat;
- op_registry_["MatMul"] = ConvertMatMul;
- op_registry_["Reshape"] = ConvertReshape;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
- plugin_converter_ = ConvertPlugin;
+ op_registry_["Div"] = ConvertBinary;
+ op_registry_["RealDiv"] = ConvertBinary;
+
+ op_registry_["Rsqrt"] = ConvertUnary;
+ op_registry_["Reciprocal"] = ConvertUnary;
+ op_registry_["Exp"] = ConvertUnary;
+ op_registry_["Log"] = ConvertUnary;
+ op_registry_["Sqrt"] = ConvertUnary;
+ op_registry_["Abs"] = ConvertUnary;
+ op_registry_["Neg"] = ConvertUnary;
+#if NV_TENSORRT_MAJOR == 3
+ op_registry_["Mean"] = ConvertReducePool;
+#endif
+#if NV_TENSORRT_MAJOR > 3
+ op_registry_["Sum"] = ConvertReduce;
+ op_registry_["Prod"] = ConvertReduce;
+ op_registry_["Max"] = ConvertReduce;
+ op_registry_["Min"] = ConvertReduce;
+ op_registry_["Mean"] = ConvertReduce;
+ op_registry_["Maximum"] = ConvertBinary;
+ op_registry_["Minimum"] = ConvertBinary;
+ op_registry_["Softmax"] = ConvertSoftmax;
+ op_registry_["MatMul"] = ConvertMatMul;
+ op_registry_["BatchMatMul"] = ConvertBatchMatMul;
+ op_registry_["TopKV2"] = ConvertTopK;
+#endif
}
} // namespace
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 8a17eb02f1..04d072f5d9 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -316,6 +316,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 inputs are not supported!"));
return;
+#if NV_TENSORRT_MAJOR > 3
+ case nvinfer1::DataType::kINT32:
+ buffers[binding_index] = (void*)(input_tensor.flat<int32>().data());
+ break;
+#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
@@ -368,6 +373,12 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 outputs are not supported!"));
return;
+#if NV_TENSORRT_MAJOR > 3
+ case nvinfer1::DataType::kINT32:
+ buffers[binding_index] =
+ reinterpret_cast<void*>(output_tensor->flat<int32>().data());
+ break;
+#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
index 383635f428..e0c7b62723 100644
--- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -42,8 +42,14 @@ REGISTER_OP("TRTEngineOp")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
.Attr("calibration_data: string = ''")
.Input("in_tensor: InT")
- .Output("out_tensor: OutT")
- .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+ .Output("out_tensor: OutT");
+// TODO(jie): TF requires a concrete output shape for concrete input shapes.
+// This is tricky for the batch dimension, since we cannot ensure which input
+// carries the correct batch dimension (at the current stage of the
+// implementation we require all input tensors to carry the same batch size,
+// but this could change in the future). Hence we disable the shape inference
+// function as a workaround.
+// .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 227ac120dd..f30dba59ad 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -28,36 +28,50 @@ limitations under the License.
namespace tensorflow {
namespace shape_inference {
-tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
- std::vector<tensorflow::TensorShape> shapes;
- for (int i = 0; i < context->num_outputs(); ++i) {
- context->set_output(i, context->UnknownShape());
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->UnknownShape());
}
- auto status = context->GetAttr("input_shapes", &shapes);
- // it is ok to not to have shapes
- if (!status.ok()) return Status::OK();
- if ((int)shapes.size() != context->num_inputs()) return Status::OK();
- bool different_input = false;
- for (int i = 0; i < context->num_inputs(); ++i) {
- if (shapes.at(i) != context->input_tensor(i)->shape())
- different_input = true;
+
+ // Check the sanity of the input shapes.
+ std::vector<tensorflow::TensorShape> input_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("input_shapes", &input_shapes));
+ if (input_shapes.size() != c->num_inputs()) {
+ return tensorflow::errors::InvalidArgument(
+ "The actual number of inputs doesn't match the number of input "
+ "shapes set in the attr: ",
+ c->num_inputs(), " vs ", input_shapes.size());
+ }
+ bool input_match = true;
+ for (int i = 0; i < c->num_inputs(); ++i) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromTensorShape(input_shapes.at(i), &handle));
+ ShapeHandle merged;
+    if (!c->Merge(c->input(i), handle, &merged).ok()) {
+      // The actual input shape doesn't match what was recorded in the attr;
+      // that's fine, the outputs simply stay unknown.
+      input_match = false;
+    }
}
- if (different_input) return Status::OK();
- shapes.resize(0);
- status = context->GetAttr("output_shapes", &shapes);
- if (!status.ok()) return Status::OK();
- if ((int)shapes.size() != context->num_outputs()) return Status::OK();
- std::vector<ShapeHandle> shape_handles(shapes.size());
- for (size_t i = 0; i < shapes.size(); ++i) {
- status =
- context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i));
- if (!status.ok()) return Status::OK();
+
+ // Check the sanity of the output shapes.
+ std::vector<tensorflow::TensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return tensorflow::errors::InvalidArgument(
+ "The actual number of outputs doesn't match the number of output "
+ "shapes set in the attr: ",
+ c->num_outputs(), " vs ", output_shapes.size());
}
- for (int i = 0; i < context->num_outputs(); ++i) {
- context->set_output(i, shape_handles.at(i));
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ ShapeHandle handle;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromTensorShape(output_shapes.at(i), &handle));
+ if (input_match) c->set_output(i, handle);
}
return Status::OK();
}
+
} // namespace shape_inference
} // namespace tensorflow
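
The control flow of the rewritten shape function, paraphrased as a Python sketch with shapes as plain tuples (an illustration, not the InferenceContext API): shape counts that disagree with the attrs are hard errors, while value mismatches merely leave the outputs unknown.

    def infer_outputs(recorded_in, actual_in, recorded_out):
        if len(recorded_in) != len(actual_in):
            raise ValueError('input count mismatch: %d vs %d'
                             % (len(actual_in), len(recorded_in)))
        input_match = all(r == a for r, a in zip(recorded_in, actual_in))
        return list(recorded_out) if input_match else [None] * len(recorded_out)

    print(infer_outputs([(2, 3)], [(2, 3)], [(2, 5)]))  # [(2, 5)]
    print(infer_outputs([(2, 3)], [(4, 3)], [(2, 5)]))  # [None]: unknown
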
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index ec9a7861e7..7020989d68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -157,6 +157,7 @@ py_library(
py_test(
name = "head_test",
+ size = "large",
srcs = [
"head_test.py",
],
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index c08f088be7..0044fde9d0 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -161,12 +161,43 @@ py_library(
)
py_library(
+ name = "keras_support",
+ srcs = [
+ "python/tpu/keras_support.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tpu_lib",
+ ":tpu_py",
+ "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tensor_spec",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/keras:backend",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "tpu_lib",
srcs = [
"python/tpu/__init__.py",
"python/tpu/bfloat16.py",
"python/tpu/device_assignment.py",
- "python/tpu/keras_support.py",
"python/tpu/session_support.py",
"python/tpu/topology.py",
"python/tpu/tpu.py",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index dc90668559..d62338680e 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -45,6 +45,8 @@
@@RunConfig
@@InputPipelineConfig
@@TPUConfig
+
+@@bfloat16_scope
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 7541544382..722e31abb2 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -59,10 +59,12 @@ from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
@@ -71,7 +73,9 @@ from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -99,6 +103,45 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
+ """An optimizer that averages gradients across TPU shards."""
+
+ def __init__(self, opt, name='KerasCrossShardOptimizer'):
+ """Construct a new cross-shard optimizer.
+
+ Args:
+ opt: An existing `Optimizer` to encapsulate.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "KerasCrossShardOptimizer".
+
+ Raises:
+ ValueError: If reduction is not a valid cross-shard reduction.
+ """
+ super(KerasCrossShardOptimizer, self).__init__()
+ self._name = name
+ self._opt = opt
+
+ def get_updates(self, loss, params):
+ logging.info('Get updates: %s', loss)
+ self._opt.get_gradients = self.get_gradients
+ return self._opt.get_updates(loss, params)
+
+ def get_gradients(self, loss, params):
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
+ return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
+
+ def set_weights(self, weights):
+ self._opt.set_weights()
+
+ def get_weights(self):
+ return self._opt.get_weights()
+
+ @property
+ def lr(self):
+ return self._opt.lr
+
+
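Why dividing the cross-replica sum by num_shards yields the mean gradient, worked through with NumPy stand-ins (cross_replica_sum itself runs on the TPU and delivers the same sum to every shard):

    import numpy as np

    g0, g1 = np.array([1.0, 2.0]), np.array([3.0, 4.0])  # per-shard grads
    num_shards = 2
    summed = g0 + g1              # what cross_replica_sum returns, per shard
    print(summed / num_shards)    # [2. 3.], the mean gradient on every shard
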
class TPUModelOp(
collections.namedtuple('TPUModelOp', [
'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'
@@ -113,8 +156,13 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- return keras_optimizers.TFOptimizer(
- optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer))
+ if tpu_function.get_tpu_context().number_of_shards == 1:
+ return opt
+
+ if isinstance(opt, keras_optimizers.TFOptimizer):
+ return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
+ else:
+ return KerasCrossShardOptimizer(opt)
class TPURewriteContext(object):
@@ -163,8 +211,51 @@ class TPURewriteContext(object):
self._default_placeholder = array_ops.placeholder
self._default_name_scope = ops.name_scope
self._default_make_variable = base_layer.make_variable
+ self._default_random_normal = random_ops.random_normal
+ self._default_qr = gen_linalg_ops.qr
array_ops.placeholder = _placeholder
+
+ # Replace random_ops.random_normal with a dummy function because
+ # `random_normal` isn't yet implemented on the TPU. Because these
+ # initialized values are overwritten by the CPU values, this is okay.
+ def random_normal(shape,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ seed=None,
+ name=None):
+ del mean
+ del stddev
+ del seed
+ return array_ops.zeros(shape, dtype=dtype, name=name)
+
+ random_ops.random_normal = random_normal
+
+    # Replace gen_linalg_ops.qr because QR decomposition is not yet
+    # implemented on the TPU.
+ # TODO(saeta): Remove qr override once we confirm the qr implementation is
+ # ok.
+ # pylint: disable=redefined-builtin
+ def qr(input, full_matrices=False, name=None):
+ """Dummy implementation of qr decomposition."""
+ del full_matrices # TODO(saeta): Properly handle the full matrix case.
+ input_shape = input.shape
+ if len(input_shape) < 2:
+ raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+ p = min(input_shape[-1], input_shape[-2])
+ if len(input_shape) == 2:
+ q = array_ops.zeros((p, p), name=name)
+ r = array_ops.zeros(input_shape, name=name)
+ return (r, q)
+ elif len(input_shape) == 3:
+ n = input_shape[0]
+ q = array_ops.zeros((n, p, p), name=name)
+ r = array_ops.zeros(input_shape, name=name)
+ return (r, q)
+ else:
+ raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+ gen_linalg_ops.qr = qr
+
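For reference, the shape contract the dummy has to satisfy, checked with NumPy (a sketch: numpy.linalg.qr and tf.linalg.qr return (q, r) with q of shape [M, P] and r of shape [P, N], P = min(M, N), whereas the stub above returns a square q and the pair in (r, q) order, which only works because the TPU values are discarded in favor of the CPU ones):

    import numpy as np

    a = np.random.randn(5, 3)
    q, r = np.linalg.qr(a)                  # reduced QR by default
    assert q.shape == (5, 3) and r.shape == (3, 3)
    np.testing.assert_allclose(q.dot(r), a, atol=1e-10)
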
ops.name_scope = _name_scope
base_layer.make_variable = variable_scope.get_variable
logging.info('Overriding default placeholder.')
@@ -174,6 +265,8 @@ class TPURewriteContext(object):
array_ops.placeholder = self._default_placeholder
ops.name_scope = self._default_name_scope
base_layer.make_variable = self._default_make_variable
+ random_ops.random_normal = self._default_random_normal
+ gen_linalg_ops.qr = self._default_qr
class TPUFunction(object):
@@ -195,6 +288,12 @@ class TPUFunction(object):
self._compilation_cache = {}
self._cloned_model = None
+ # Copy optimizer configuration. This is done prior to `_specialize_model`
+ # as the configuration may require evaluating variables in the CPU session.
+ self._optimizer_config = None
+ if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ self._optimizer_config = self.model.optimizer.get_config()
+
def _specialize_model(self, input_specs):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
# Re-create our input and output layers inside our subgraph. They will be
@@ -236,11 +335,23 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
- self._cloned_model = models.clone_model(self.model)
+ # TODO(power): Replicate variables.
+ with ops.device('/device:TPU:0'):
+ self._cloned_model = models.clone_model(self.model)
+
+ # Create a copy of the optimizer for this graph.
+ if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ cloned_optimizer = keras_optimizers.TFOptimizer(
+ self.model.optimizer.optimizer)
+ else:
+ logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
+ self._optimizer_config)
+ cloned_optimizer = self.model.optimizer.__class__.from_config(
+ self._optimizer_config)
if is_training or is_test:
self._cloned_model.compile(
- optimizer=_replicated_optimizer(self.model.optimizer),
+ optimizer=_replicated_optimizer(cloned_optimizer),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
metrics=self.model.metrics,
@@ -365,8 +476,7 @@ class TPUFunction(object):
batch_size = inputs[0].shape[0]
assert batch_size % self._strategy.num_towers == 0, (
'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
- (batch_size, self._strategy.num_towers)
- )
+ (batch_size, self._strategy.num_towers))
shard_size = batch_size // self._strategy.num_towers
input_list = []
for index in range(self._strategy.num_towers):
@@ -438,9 +548,8 @@ class TPUFunction(object):
outputs_per_replica = len(self._outfeed_spec)
for i in range(self._strategy.num_towers):
- output_group = outfeed_outputs[
- i * outputs_per_replica:(i+1) * outputs_per_replica
- ]
+ output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
+ outputs_per_replica]
for j in range(outputs_per_replica):
outputs[j].append(output_group[j])
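
The slicing above deinterleaves the outfeed: the dequeued list holds all replicas' outputs back to back, and is regrouped per output index. A small sketch of the same arithmetic:

    def group_outfeed(flat, num_replicas, outputs_per_replica):
        outputs = [[] for _ in range(outputs_per_replica)]
        for i in range(num_replicas):
            group = flat[i * outputs_per_replica:(i + 1) * outputs_per_replica]
            for j, item in enumerate(group):
                outputs[j].append(item)
        return outputs

    print(group_outfeed(['l0', 'a0', 'l1', 'a1'], 2, 2))
    # [['l0', 'l1'], ['a0', 'a1']] -- per-replica outputs collected per index
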
@@ -470,14 +579,16 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = False
self._graph = ops.Graph()
- cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
+ self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu_name_or_address)
- cluster_spec = cluster_resolver.cluster_spec()
+ master = self._cluster_resolver.master()
+ cluster_spec = self._cluster_resolver.cluster_spec()
self._session = tf_session.Session(
graph=self._graph,
- target=cluster_resolver.master(),
+ target=master,
config=config_pb2.ConfigProto(isolate_session_state=True))
+ # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
if cluster_spec:
self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
@@ -529,11 +640,6 @@ class KerasTPUModel(models.Model):
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- # Keras optimizers are not compatible with TPU rewrite
- if not isinstance(self.optimizer, keras_optimizers.TFOptimizer):
- raise ValueError(
- 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer)
-
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
@@ -615,10 +721,10 @@ class KerasTPUModel(models.Model):
K.set_session(default_session)
def shutdown(self):
- logging.info('Shutting down TPU session.')
- with self.tpu_session() as session:
- session.run(tpu.shutdown_system())
-
+ # TODO(b/111364423): Actually shut down the system.
+ logging.info('Skipping shutting down TPU system.')
+ # with self.tpu_session() as session:
+ # session.run(tpu.shutdown_system())
self._session.close()
@@ -687,6 +793,10 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
Returns:
A new `KerasTPUModel` instance.
"""
+ # Force initialization of the CPU model.
+ model.get_weights()
+ model.reset_states()
+
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 6d7331e3c7..9e010922dc 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -23,8 +23,6 @@ import collections
import json
import os
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
@@ -43,6 +41,7 @@ class InputPipelineConfig(object):
PER_SHARD_V1 = 1
PER_HOST_V1 = 2
PER_HOST_V2 = 3
+ BROADCAST = 4
# TODO(b/72511246) Provide a simplified api to configure model parallelism.
@@ -50,7 +49,7 @@ class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
- 'computation_shape',
+ 'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
@@ -67,22 +66,22 @@ class TPUConfig(
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
product(computation_shape) * num_shards.
- computation_shape: Defaults to `None`, which disables model parallelism. A
- list of size 3 which describes the shape of a model replica's block of
- cores. This is required by model-parallelism which enables partitioning
- the model to multiple cores. For example, [2, 2, 1] means the model is
- partitioned across 4 cores which span two cores in both x and y
- coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
- geometry of a TPU mesh.
+    num_cores_per_replica: Defaults to `None`, which disables model
+      parallelism. An integer describing the number of TPU cores per model
+      replica. This is required for model parallelism, which partitions the
+      model across multiple cores. Currently num_cores_per_replica must be
+      1, 2, 4, or 8.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
- `input_fn` is invoked per-host rather than per-core. With per-host input
- pipeline configuration, `input_fn` is invoked once on each host. With the
- per-core input pipeline configuration, it is invoked once for each core.
+ `input_fn` is invoked once on each host. With the per-core input pipeline
+ configuration, it is invoked once for each core.
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
the batch size for each shard is `train_batch_size` // #hosts in the
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
- `train_batch_size` // #cores. With the per-core input pipeline
- configuration, the shard batch size is also `train_batch_size` // #cores.
+      `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
+      invoked once on host 0, and the tensors are broadcast to all other
+      replicas. The batch size equals `train_batch_size`. With the per-core
+      input pipeline configuration, the shard batch size is also
+      `train_batch_size` // #cores.
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
@@ -99,7 +98,7 @@ class TPUConfig(
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
- computation_shape=None,
+ num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None):
@@ -112,19 +111,12 @@ class TPUConfig(
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
- # Check computation_shape
- if computation_shape is not None and len(computation_shape) != 3:
- raise ValueError(
- 'computation_shape must be a list with length 3 or None; got {}'.
- format(str(computation_shape)))
-
- if computation_shape is not None:
- computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
- # This prevents any computation being replicated across multiple hosts, so
- # that each host feeds the same number of computations.
- if any(computation_shape_array < 1) or any(computation_shape_array > 2):
- raise ValueError('computation_shape elements can only be 1 or 2; got '
- 'computation_shape={}'.format(computation_shape))
+    # Validate num_cores_per_replica.
+ if num_cores_per_replica is not None:
+ if num_cores_per_replica not in [1, 2, 4, 8]:
+ raise ValueError(
+ 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
# Map legacy values (True, False) to numeric values.
@@ -144,7 +136,7 @@ class TPUConfig(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
- computation_shape=computation_shape,
+ num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs)
@@ -214,6 +206,12 @@ class RunConfig(run_config_lib.RunConfig):
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
+ def _maybe_overwrite_session_config_for_distributed_training(self):
+    # Overrides the parent class's session_config overwrite, which targets
+    # between-graph training. TPU training runs in-graph, which should not
+    # set a device filter. Doing nothing ("pass") disables the overwrite.
+ pass
+
@property
def evaluation_master(self):
return self._evaluation_master
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 37ef3dbe1e..2326fe97a8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
@@ -33,6 +34,46 @@ def _set_tf_config_env_variable(tf_config):
class TPURunConfigTest(test.TestCase):
+ def test_no_session_config_set_in_local_case(self):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertIsNone(run_config.session_config)
+
+ def test_no_session_config_overwrite_in_local_case(self):
+ session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ run_config = tpu_config_lib.RunConfig(session_config=session_config)
+ self.assertEqual(session_config, run_config.session_config)
+
+ def test_no_session_config_set_with_cluster_spec(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host3:3'],
+ run_config_lib.TaskType.WORKER: ['host3:4']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertIsNone(run_config.session_config)
+
+ def test_no_session_config_overwrite_with_cluster_spec(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host3:3'],
+ run_config_lib.TaskType.WORKER: ['host3:4']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ with _set_tf_config_env_variable(tf_config):
+ session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ run_config = tpu_config_lib.RunConfig(session_config=session_config)
+ self.assertEqual(session_config, run_config.session_config)
+
def test_fail_with_invalid_num_shards(self):
with self.assertRaisesRegexp(ValueError, 'must be positive'):
tpu_config_lib.RunConfig(
@@ -43,15 +84,11 @@ class TPURunConfigTest(test.TestCase):
tpu_config_lib.RunConfig(
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
- def test_fail_with_invalid_computation_shape(self):
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape must be a list with length'
- ' 3 or None'):
- tpu_config_lib.TPUConfig(computation_shape=[2, 1])
-
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape elements can only be'):
- tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1])
+ def test_fail_with_invalid_num_cores_per_replica(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ' got 7'):
+ tpu_config_lib.TPUConfig(num_cores_per_replica=7)
class TPURunConfigMasterTest(test.TestCase):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index aec59f3885..e54395f05d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -21,8 +21,6 @@ from __future__ import print_function
from contextlib import contextmanager
import copy
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
@@ -33,15 +31,26 @@ from tensorflow.python.platform import tf_logging as logging
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
+_NUM_CORES_TO_COMPUTATION_SHAPE = {
+ 1: [1, 1, 1],
+ 2: [1, 1, 2],
+ 4: [1, 2, 2],
+ 8: [2, 2, 2]
+}
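
A quick sanity check of this table (illustration only): each computation shape spans exactly num_cores_per_replica cores, i.e. the product of its entries equals the key, which is why TPUConfig accepts only 1, 2, 4, and 8.

    from functools import reduce
    from operator import mul

    table = {1: [1, 1, 1], 2: [1, 1, 2], 4: [1, 2, 2], 8: [2, 2, 2]}
    for cores, shape in table.items():
        assert reduce(mul, shape) == cores
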
class TPUContext(object):
"""The context of current input_fn invocation."""
- def __init__(self, internal_ctx, input_device=None, invocation_index=None):
+ def __init__(self,
+ internal_ctx,
+ input_device=None,
+ invocation_index=None,
+ call_from_input_fn=True):
self._internal_ctx = internal_ctx
self._input_device = input_device
self._invocation_index = invocation_index
+ self._call_from_input_fn = call_from_input_fn
def current_input_fn_deployment(self):
"""The configuration of the current input_fn invocation.
@@ -69,11 +78,21 @@ class TPUContext(object):
total invocation count is equal to the number of hosts in the system
and num replicas consumed by current invocation is equal to number of
cores per host.
+
+ Raises:
+      RuntimeError: If this method is called from model_fn.
"""
+ if not self._call_from_input_fn:
+ raise RuntimeError('This TPUContext instance must not be called from'
+ ' model_fn.')
+
if self._internal_ctx.is_input_sharded_per_core():
total_invocation_count = (self._internal_ctx.num_hosts
* self._internal_ctx.num_of_replicas_per_host)
replicas_consumed = 1
+ elif self._internal_ctx.is_input_broadcast_with_iterators():
+ total_invocation_count = 1
+ replicas_consumed = self._internal_ctx.num_replicas
else:
total_invocation_count = self._internal_ctx.num_hosts
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
@@ -105,6 +124,14 @@ class TPUContext(object):
'num_of_replicas_per_host is not supported for model_parallelism')
return self._internal_ctx.num_of_replicas_per_host
+ @property
+ def device_assignment(self):
+ """Returns device_assignment object."""
+ if self._call_from_input_fn:
+ raise RuntimeError('This TPUContext instance must not be called from'
+ ' input_fn.')
+ return self._internal_ctx.device_assignment
+
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
@@ -121,8 +148,8 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
# If the precise replica_id to device mapping is required, please
- # set the computation_shape as [1,1,1] in TPUConfig to enable
- # the model parallelism.
+ # set the num_cores_per_replica to 1 in TPUConfig to enable the
+ # model parallelism.
if self._internal_ctx.model_parallelism_enabled:
return RuntimeError(
'device_for_replica is not yet implemented for model parallelism. '
@@ -175,9 +202,14 @@ class _InternalTPUContext(object):
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
- use_tpu and config.tpu_config.computation_shape)
+ use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
-
+ num_cores_per_replica = config.tpu_config.num_cores_per_replica
+ if num_cores_per_replica:
+ self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
+ num_cores_per_replica]
+ else:
+ self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
@@ -202,7 +234,7 @@ class _InternalTPUContext(object):
def mode(self):
return self._assert_mode()
- def _get_master_address(self):
+ def master_address(self):
mode = self._assert_mode()
config = self._config
master = (
@@ -212,7 +244,7 @@ class _InternalTPUContext(object):
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
- master = self._get_master_address()
+ master = self.master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
@@ -229,7 +261,7 @@ class _InternalTPUContext(object):
def _get_device_assignment(self):
"""Gets the (maybe cached) TPU device assignment."""
- master = self._get_master_address()
+ master = self.master_address()
device_assignment = self._lazy_device_assignment_dict.get(master)
if device_assignment is not None:
return device_assignment
@@ -238,11 +270,12 @@ class _InternalTPUContext(object):
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
- computation_shape=self._config.tpu_config.computation_shape,
+ computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
- logging.info('computation_shape: %s',
- str(self._config.tpu_config.computation_shape))
+ logging.info('num_cores_per_replica: %s',
+ str(self._config.tpu_config.num_cores_per_replica))
+ logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
@@ -283,23 +316,20 @@ class _InternalTPUContext(object):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
- computation_shape_array = np.asarray(
- self._config.tpu_config.computation_shape, dtype=np.int32)
- num_cores_per_replica = np.prod(computation_shape_array)
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
- 'TPUConfig.computation_shape, is larger than the total num of '
- 'TPU cores in the system. computation_shape: {}, num cores '
- 'in the system: {}'.format(
- self._config.tpu_config.computation_shape,
- num_cores_in_system))
+ 'TPUConfig.num_cores_per_replica, is larger than the total num of '
+ 'TPU cores in the system. num_cores_per_replica: {}, num cores '
+ 'in the system: {}'.format(num_cores_per_replica,
+ num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
- 'TPUConfig.computation_shape. This should never happen!'.format(
+ 'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
@@ -327,6 +357,11 @@ class _InternalTPUContext(object):
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_HOST_V2)
+ def is_input_broadcast_with_iterators(self):
+    """Returns true if input_fn should run in the BROADCAST configuration."""
+ return (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.BROADCAST)
+
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
@@ -391,7 +426,7 @@ class _InternalTPUContext(object):
"""Returns the shard batch size for `input_fn`."""
global_batch_size = self.global_batch_size
- if self.is_running_on_cpu():
+ if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU
@@ -406,7 +441,7 @@ class _InternalTPUContext(object):
"""Returns the shard batch size for `model_fn`."""
global_batch_size = self.global_batch_size
- if self.is_running_on_cpu():
+ if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU. always sharded per shard.
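
The per-invocation batch size implied by each input mode, condensed into a sketch (the function and the mode strings are illustrative, mirroring the TPUConfig docstring rather than the actual enum plumbing):

    def input_fn_batch_size(global_batch, mode, num_hosts, num_cores):
        if mode == 'BROADCAST':
            return global_batch            # input_fn runs once, on host 0
        if mode == 'PER_HOST_V1':
            return global_batch // num_hosts
        # PER_HOST_V2 and the per-core (PER_SHARD_V1) configurations
        return global_batch // num_cores

    print(input_fn_batch_size(128, 'BROADCAST', num_hosts=2, num_cores=16))
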
@@ -463,17 +498,23 @@ class _InternalTPUContext(object):
master = self.master_job
- def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name
+ def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
+ """Return the host device given replica_id or host_id."""
assert _sentinal is None
- if core_id is not None and host_id is not None:
+ if replica_id is not None and host_id is not None:
raise RuntimeError(
- 'core_id and host_id can have only one non-None value.')
+ 'replica_id and host_id can have only one non-None value.')
if master is None:
return '/replica:0/task:0/device:CPU:0'
else:
- if core_id is not None:
- host_id = core_id / self.num_of_cores_per_host
+ if replica_id is not None:
+ if self.model_parallelism_enabled:
+ return self.device_assignment.host_device(
+ replica=replica_id, job=master)
+ else:
+ host_id = replica_id / self.num_of_cores_per_host
+
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
return _placement_function
@@ -546,9 +587,9 @@ class _InternalTPUContext(object):
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
- 'product(computation_shape) * num_replicas. Please set it '
+ 'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
- self._get_master_address(), num_replicas,
+ self.master_address(), num_replicas,
user_provided_num_replicas))
raise ValueError(message)
@@ -603,7 +644,7 @@ class _OneCoreTPUContext(_InternalTPUContext):
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
- master = self._get_master_address()
+ master = self.master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
@@ -625,7 +666,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
- config.tpu_config.computation_shape is None):
+ config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
         'Please fix as soon as possible (leaving num_shards as None).')
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 49cd318b89..aa407cf4d8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.contrib.training.python.training import hparam
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -67,6 +68,7 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -382,7 +384,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._init_ops = []
+ # For distributed sessions, we can't run initialize_system in a separate
+ # graph here because 'begin' is only invoked when the MonitoredSession is
+ # created. We need to reinitialize the system every time MonitoredSession
+ # creates an underlying tf.Session, so we initialize from Scaffold.finalize.
+ # See _get_and_wrap_scaffold for more details.
+ if self._master_job is None:
+ self._init_ops.append(tpu.initialize_system(job=self._master_job))
self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
@@ -484,7 +493,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
return _OpQueueContext(name=name, target=target, args=args)
def after_create_session(self, session, coord):
- logging.info('Init TPU system')
+ logging.info('Running init_ops')
session.run(self._init_ops,
options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
@@ -847,6 +856,65 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
+ num_hosts):
+ """Generates infeed enqueue ops for one input_fn on all the hosts."""
+ captured_infeed_queue = _CapturedObject()
+ hooks = []
+ device_0 = ctx.tpu_host_placement_function(host_id=0)
+ with ops.device(device_0):
+ user_context = tpu_context.TPUContext(
+ internal_ctx=ctx, input_device=device_0, invocation_index=0)
+ inputs = _Inputs.from_input_fn(input_fn(user_context))
+
+ is_dataset = inputs.is_dataset
+ if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+ raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.')
+
+ hooks.append(inputs.dataset_initializer_hook())
+ num_replicas_per_host = ctx.num_of_replicas_per_host
+
+ def tpu_ordinal_function_impl(replica_id):
+ if ctx.device_assignment:
+ return ctx.device_assignment.tpu_ordinal(replica_id=replica_id)
+ else:
+ return replica_id % num_replicas_per_host
+
+ def device_function_impl(replica_id):
+ return ctx.tpu_host_placement_function(replica_id=replica_id)
+
+ def enqueue_ops_fn():
+ """Generates enqueue ops for all the hosts."""
+ broadcasted_inputs = []
+ flattened_inputs = None # Cache result from input_fn.
+ for host_id in xrange(num_hosts):
+ with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
+ for _ in xrange(ctx.num_of_replicas_per_host):
+          # Note: input_fn is only called once, at host 0 for the first
+          # replica. The features and labels returned from that invocation
+          # are broadcast to the other replicas (including the replicas on
+          # other hosts).
+ if flattened_inputs is None:
+ features, labels = inputs.features_and_labels() # Calls get_next()
+ inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ flattened_inputs = (
+ inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+ broadcasted_inputs.append(flattened_inputs)
+
+ infeed_queue = tpu_feed.InfeedQueue(
+ number_of_tuple_elements=len(broadcasted_inputs[0]))
+ captured_infeed_queue.capture(infeed_queue)
+ enqueue_ops = infeed_queue.generate_enqueue_ops(
+ broadcasted_inputs,
+ tpu_ordinal_function=tpu_ordinal_function_impl,
+ placement_function=device_function_impl)
+ return enqueue_ops
+
+ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
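The fan-out in enqueue_ops_fn, reduced to its essentials (a sketch with a plain callable standing in for the dataset's get_next): the input is materialized once and the same flattened tensors are enqueued for every replica on every host.

    def broadcast_inputs(get_next, num_hosts, replicas_per_host):
        cached, broadcasted = None, []
        for _ in range(num_hosts):
            for _ in range(replicas_per_host):
                if cached is None:
                    cached = get_next()    # called exactly once, on host 0
                broadcasted.append(cached)
        return broadcasted

    fanned = broadcast_inputs(lambda: ('features', 'labels'), 2, 8)
    assert len(fanned) == 16 and all(x is fanned[0] for x in fanned)
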
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
@@ -1079,6 +1147,22 @@ class _InputPipeline(object):
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
infeed_queues.append(captured_infeed_queue.get())
+ elif self._ctx.is_input_broadcast_with_iterators():
+ # Only calls input_fn in host 0.
+ host_device = tpu_host_placement_fn(host_id=0)
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
+ self._inputs_structure_recorder,
+ num_hosts))
+ all_hooks.extend(hooks)
+ if is_dataset:
+ run_infeed_loop_on_coordinator = False
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
+ else:
+ enqueue_ops.append(enqueue_ops_fn())
+ infeed_queues.append(captured_infeed_queue.get())
else:
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
@@ -1422,6 +1506,11 @@ class _ModelFnWrapper(object):
running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
_add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
+ if not running_on_cpu:
+ user_context = tpu_context.TPUContext(
+ internal_ctx=self._ctx, call_from_input_fn=False)
+ _add_item_to_params(params, _CTX_KEY, user_context)
+
estimator_spec = self._model_fn(features=features, **kwargs)
if (running_on_cpu and
isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
@@ -1601,7 +1690,7 @@ class _OutfeedHostCall(object):
# place all ops on tpu host if possible.
#
# TODO(jhseu): Evaluate whether this is right for summaries.
- with ops.device(self._ctx.tpu_host_placement_function(core_id=0)):
+ with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
for name in self._names:
dequeue_ops = dequeue_ops_by_name[name]
for i, item in enumerate(dequeue_ops):
@@ -1986,7 +2075,7 @@ class TPUEstimator(estimator_lib.Estimator):
if (config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1 and
- config.tpu_config.computation_shape):
+ config.tpu_config.num_cores_per_replica):
raise ValueError(
'Model parallelism only supports per host input for training. '
'Please adjust TPURunconfig.per_host_input_for_training.')
@@ -2298,10 +2387,20 @@ class TPUEstimator(estimator_lib.Estimator):
# Clear the bit.
self._is_input_fn_invoked = None
+ # examples_hook is added to training_hooks for both CPU and TPU
+ # execution.
+ examples_hook = ExamplesPerSecondHook(
+ ctx.global_batch_size,
+ output_dir=self.model_dir,
+ every_n_steps=self._log_every_n_steps)
+
if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
logging.info('Running %s on CPU', mode)
- return model_fn_wrapper.call_without_tpu(
+ estimator_spec = model_fn_wrapper.call_without_tpu(
features, labels, is_export_mode=is_export_mode)
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks + (examples_hook,))
+ return estimator_spec
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
# TPUEstimator._call_input_fn passes `input_fn` as features to here.
@@ -2369,10 +2468,6 @@ class TPUEstimator(estimator_lib.Estimator):
},
every_n_iter=logging_hook_frequency)
])
- examples_hook = ExamplesPerSecondHook(
- ctx.global_batch_size,
- output_dir=self.model_dir,
- every_n_steps=self._log_every_n_steps)
examples_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)
@@ -2614,7 +2709,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return loss, host_calls, scaffold
@@ -2637,7 +2732,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return loss, host_call, scaffold
@@ -2665,7 +2760,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
num_shards=num_cores,
outputs_from_all_shards=False)
- scaffold = _get_scaffold(captured_scaffold_fn)
+ scaffold = _get_and_wrap_scaffold(captured_scaffold_fn, ctx)
return dummy_predict_op, host_calls, scaffold
@@ -2755,8 +2850,20 @@ class _CapturedObject(object):
return self._object
-def _get_scaffold(captured_scaffold_fn):
- """Retrieves the Scaffold from `captured_scaffold_fn`."""
+def _get_and_wrap_scaffold(captured_scaffold_fn, ctx):
+ """Retrieves the Scaffold from `captured_scaffold_fn`.
+
+ Also wraps the scaffold's finalize method to initialize the TPU after the
+ graph is finalized.
+
+ Args:
+ captured_scaffold_fn: a `_CapturedObject` containing a scaffold_fn.
+ ctx: A `_InternalTPUContext` instance used to initialize the TPU.
+
+ Returns:
+ The Scaffold produced by captured_scaffold_fn, wrapped to initialize the TPU
+ after the graph is finalized.
+ """
with _CapturingContext(message='Inside scaffold_fn'):
scaffold_fn = captured_scaffold_fn.get()
if scaffold_fn:
@@ -2767,14 +2874,64 @@ def _get_scaffold(captured_scaffold_fn):
else:
scaffold = None
- if scaffold:
- wrapped_finalize = scaffold.finalize
-
- def _finalize():
- with _CapturingContext('Inside Scaffold.finalize'):
- wrapped_finalize()
-
- scaffold.finalize = _finalize
+ if scaffold is None:
+ # When master_address is None, we are using DirectSession, so we can't
+ # invoke initialize_system from finalize. See comments below.
+ if ctx.master_address() is None:
+ return scaffold
+ scaffold = monitored_session.Scaffold()
+
+ wrapped_finalize = scaffold.finalize
+
+ def _finalize():
+ """Invoke wrapped_finalize and initialize the TPU."""
+ with _CapturingContext('Inside Scaffold.finalize'):
+ wrapped_finalize()
+ # Run tpu.initialize_system in its own graph after finalizing the main graph
+ # for distributed sessions. This is necessary because the TPU must be
+ # initialized before the TPU graph rewrite pass runs. We can't put the
+ # initialization op in the main graph because the main graph also contains
+ # replicate ops created by tpu.shard. If we tried to run initialization from
+ # the main graph, the TPU graph rewrite pass would rewrite the replicate ops
+ # before actually evaluating the initialization ops.
+ #
+ # For distributed sessions, the master may independently restart. After a
+ # master restarts, the rewrite pass runs again when any op in the main graph
+ # runs, so we must reinitialize the system every time the main graph is
+ # finalized.
+ #
+ # Special case: When master_address is unset, we're using DirectSession.
+ # DirectSession resets device state between sessions, and uses
+ # place_pruned_graph. Initialization currently passes state to replication
+ # through the TPU_SYSTEM resource manager. Under DirectSession, this
+ # resource manager gets reset when init_session is closed, so DirectSession
+ # can't initialize here, and must instead initialize from the main graph's
+ # init_ops. This is possible with DirectSession because it uses
+ # place_pruned_graph, which removes unreferenced ops before invoking the
+ # rewrite pass. This makes it possible to run init_ops from the main graph,
+ # which contains both tpu.initialize_system and tpu.shard ops, without first
+ # triggering the TPU graph rewrite. We can't do this for distributed
+ # sessions because they don't support place_pruned_graph.
+ #
+ # TODO(b/110943344) Clean this up as part of the initialize_system dataflow
+ # cleanup. It should be possible to remove the special case for
+ # DirectSession and the other call to initialize_system from
+ # _obtain_topology, when topology info is always explicitly passed from
+ # tpu.initialize_system to tpu.shard, though this requires editing or
+ # rebuilding the main graph each time the master restarts.
+ if ctx.master_address() is None:
+ return
+ with ops.Graph().as_default():
+ logging.info('Init TPU system master_address %s', ctx.master_address())
+ with session_lib.Session(
+ ctx.master_address(),
+ config=ctx.config.session_config) as init_session:
+ run_options = config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)
+ init_session.run(
+ tpu.initialize_system(job=ctx.master_job), options=run_options)
+ logging.info('TPU system initialized')
+
+ scaffold.finalize = _finalize
return scaffold
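The replacement above generalizes the old `_get_scaffold` wrapper: instead of only capturing errors raised inside `Scaffold.finalize`, it now also chains a TPU-initialization step after finalization. The wrapping idiom itself is plain Python; a minimal sketch with a stand-in `Scaffold` (no TensorFlow required):

```python
class Scaffold(object):
  """Stand-in for tf.train.Scaffold; only finalize() matters here."""

  def finalize(self):
    print('graph finalized')


def wrap_finalize(scaffold, post_finalize_fn):
  wrapped_finalize = scaffold.finalize  # keep a handle on the original

  def _finalize():
    wrapped_finalize()   # finalize the graph first...
    post_finalize_fn()   # ...then run the extra step, e.g. TPU init

  scaffold.finalize = _finalize
  return scaffold


scaffold = wrap_finalize(Scaffold(),
                         lambda: print('TPU system initialized'))
scaffold.finalize()  # prints both messages, in order
```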
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index 604e6600c8..a44b4f4622 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -461,7 +461,10 @@ class InfeedQueue(object):
name=full_name,
device_ordinal=tpu_ordinal)
- def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None):
+ def generate_enqueue_ops(self,
+ sharded_inputs,
+ tpu_ordinal_function=None,
+ placement_function=None):
"""Generates the host-side Ops to enqueue the shards of a tuple.
sharded_inputs is a list, one for each shard, of lists of
@@ -483,6 +486,9 @@ class InfeedQueue(object):
shard index as input and returns the ordinal of the TPU device
the shard's infeed should be placed on. tpu_ordinal_function must be
set if the inputs are placed on CPU devices.
+ placement_function: if not None, a function that takes the shard index as
+ input and returns the host device on which the enqueue op should be
+ placed.
Returns:
A list of host-side Ops, one for each shard, that when executed together
@@ -508,8 +514,12 @@ class InfeedQueue(object):
tpu_ordinal_function = lambda index: -1
name_prefix = "%s/enqueue" % self._name
return [
- self._generate_enqueue_op(shard, name_prefix, index,
- tpu_ordinal=tpu_ordinal_function(index))
+ self._generate_enqueue_op(
+ shard,
+ name_prefix,
+ index,
+ tpu_ordinal=tpu_ordinal_function(index),
+ device=placement_function(index) if placement_function else None)
for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
]
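With the new `placement_function` keyword, a multi-host input pipeline can pin each shard's enqueue op to a specific host while `tpu_ordinal_function` selects the core. A hypothetical sharding scheme (the device strings and the 8-cores-per-host layout are illustrative, not taken from this change):

```python
num_cores_per_host = 8

# Shard i runs on core (i mod 8) of host (i // 8).
tpu_ordinal_function = lambda index: index % num_cores_per_host
placement_function = lambda index: (
    '/job:worker/task:%d/device:CPU:0' % (index // num_cores_per_host))

for shard_index in (0, 7, 8, 15):
  print(shard_index,
        placement_function(shard_index),
        tpu_ordinal_function(shard_index))
# These callables would be passed to InfeedQueue.generate_enqueue_ops(
#     sharded_inputs, tpu_ordinal_function=..., placement_function=...).
```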
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index 15f99d7eeb..53d33f4077 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -23,6 +23,7 @@ import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer
@@ -153,8 +154,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
- summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
- grad, self._group_assignment), var))
+ with ops.colocate_with(grad):
+ summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
+ grad, self._group_assignment), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
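The `colocate_with` wrapper above asks the placer to keep each `cross_replica_sum` on the same device as the gradient it reduces, avoiding an implicit copy to whatever device would otherwise win placement. A small graph-mode illustration of the constraint (assumes a TF 1.x environment; `add_n` stands in for `cross_replica_sum`):

```python
import tensorflow as tf  # assumes TF 1.x graph mode

g = tf.Graph()
with g.as_default():
  with tf.device('/device:GPU:1'):
    grad = tf.constant([1.0, 2.0])   # stand-in for a computed gradient
  # colocate_with records a colocation constraint rather than a device
  # string, so the placer puts the sum wherever `grad` actually lands.
  with tf.colocate_with(grad):
    summed = tf.add_n([grad, grad])  # stand-in for cross_replica_sum
  print(summed.op.colocation_groups())  # e.g. [b'loc:@Const']
```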
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py
new file mode 100644
index 0000000000..ed0f398e30
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py
@@ -0,0 +1,187 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""SGDR learning rate decay function."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops, control_flow_ops
+
+
+def sgdr_decay(learning_rate, global_step, initial_period_steps,
+ t_mul=2.0, m_mul=1.0, name=None):
+ """Implements Stochastic Gradient Descent with Warm Restarts (SGDR).
+
+ As described in "SGDR: Stochastic Gradient Descent
+ with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of
+ ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf
+
+ The learning rate decreases according to cosine annealing:
+
+ ```python
+ learning_rate * 0.5 * (1 + cos(x_val * pi)) # for x_val defined in [0, 1]
+ ```
+
+ Thus, at the beginning (when the restart index i = 0),
+ the learning rate decreases for `initial_period_steps` steps from the initial
+ learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to
+ 0 (when `x_val=1`, we get `cos(pi)=-1`).
+
+ The decrease within the i-th period takes `t_i` steps,
+ where `t_0` = `initial_period_steps` is the user-defined number of batch
+ iterations (not epochs as in the paper) to be performed before the first
+ restart is launched.
+
+ Then, we perform the first restart (i=1) by setting the learning rate to
+ `learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default).
+ The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new
+ restart runs `t_mul` times longer than the previous one.
+
+ Importantly, when one has no access to a validation set, SGDR suggests
+ reporting the best expected / recommended solution in the following way:
+ while within the initial run (i=0), every new solution represents
+ SGDR's recommended solution. When i>0, the recommended solution is instead
+ the one obtained at the end of each restart.
+
+ Note that the minimum learning rate is set to 0 for simplicity;
+ you can adjust the code to handle any positive minimum learning rate
+ as defined in the paper.
+
+ `initial_period_steps` is the duration of the first period measured in terms
+ of the number of minibatch updates. If one wants to use epochs, one should
+ compute the number of updates required for an epoch.
+
+ For example, assume the following parameters and intention:
+ Minibatch size: 100
+ Training dataset size: 10000
+ If the user wants the first decay period to span across 5 epochs, then
+ `initial_period_steps` = 5 * 10000/100 = 500
+
+ Train for 10000 batch iterations with the initial learning rate set to
+ 0.1, then restart to run 2 times longer, i.e., for 20000 batch iterations
+ with the initial learning rate 0.05; then restart again and again,
+ doubling the runtime of each new period and halving the initial
+ learning rate.
+
+ To accomplish the above, one would write:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ learning_rate = sgdr_decay(starter_learning_rate, global_step,
+ initial_period_steps=10000, t_mul=2, m_mul=0.5)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate)
+ .minimize(...my loss..., global_step=global_step)
+ )
+
+ # Step | 0 | 1000 | 5000 | 9000 | 9999 | 10000 | 11000 |
+ # LR | 0.1 | 0.097 | 0.05 | 0.002 | 0.00 | 0.05 | 0.0496 |
+
+ # Step | 20000 | 29000 | 29999 | 30000 |
+ # LR | 0.025 | 0.0003 | 0.00 | 0.025 |
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ initial_period_steps: Duration of the first period measured as the number
+ of minibatch updates. If one wants to use epochs, one should compute
+ the number of updates required for an epoch.
+ t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Must be positive.
+ Used to derive the number of iterations in the i-th period:
+ `initial_period_steps * (t_mul^i)`. Defaults to 2.0.
+ m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Must be positive.
+ Used to derive the initial learning rate of the i-th period:
+ `learning_rate * (m_mul^i)`. Defaults to 1.0.
+
+ Returns:
+ A scalar `Tensor` of the same type as `learning_rate`.
+ The learning rate for a provided global_step.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+
+ if global_step is None:
+ raise ValueError("global_step is required for sgdr_decay.")
+ with ops.name_scope(name, "SGDRDecay",
+ [learning_rate, global_step,
+ initial_period_steps, t_mul, m_mul]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate,
+ name="initial_learning_rate")
+ dtype = learning_rate.dtype
+ global_step = math_ops.cast(global_step, dtype)
+ t_0 = math_ops.cast(initial_period_steps, dtype)
+ t_mul = math_ops.cast(t_mul, dtype)
+ m_mul = math_ops.cast(m_mul, dtype)
+
+ c_one = math_ops.cast(constant_op.constant(1.0), dtype)
+ c_half = math_ops.cast(constant_op.constant(0.5), dtype)
+ c_pi = math_ops.cast(constant_op.constant(math.pi), dtype)
+
+ # Find normalized value of the current step
+ x_val = math_ops.div(global_step, t_0)
+
+ def compute_step(x_val, geometric=False):
+ if geometric:
+ # Consider geometric series where t_mul != 1
+ # 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul)
+
+ # First find how many restarts were performed for a given x_val
+ # Find maximal integer i_restart value for which this equation holds
+ # x_val >= (1 - t_mul^i_restart) / (1 - t_mul)
+ # x_val * (1 - t_mul) <= (1 - t_mul^i_restart)
+ # t_mul^i_restart <= (1 - x_val * (1 - t_mul))
+
+ # TensorFlow only provides the natural logarithm, so
+ # i_restart <= log(1 - x_val * (1 - t_mul)) / log(t_mul)
+ # Find how many restarts were performed
+
+ i_restart = math_ops.floor(
+ math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul))
+ # Compute the sum of all restarts before the current one
+ sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul)
+ # Compute our position within the current restart
+ x_val = (x_val - sum_r) / t_mul ** i_restart
+
+ else:
+ # Find how many restarts were performed
+ i_restart = math_ops.floor(x_val)
+ # Compute our position within the current restart
+ x_val = x_val - i_restart
+ return i_restart, x_val
+
+ i_restart, x_val = control_flow_ops.cond(
+ math_ops.equal(t_mul, c_one),
+ lambda: compute_step(x_val, geometric=False),
+ lambda: compute_step(x_val, geometric=True))
+
+ # If m_mul < 1, then the initial learning rate of every new restart will be
+ # smaller, i.e., by a factor of m_mul ** i_restart at i_restart-th restart
+ m_fac = learning_rate * (m_mul ** i_restart)
+
+ return math_ops.multiply(c_half * m_fac,
+ (math_ops.cos(x_val * c_pi) + c_one), name=name)
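The geometric-series bookkeeping in `compute_step` is easy to sanity-check outside of TensorFlow. A pure-Python replica of the same arithmetic, checked against the step/LR table from the docstring (illustrative only, not part of the change):

```python
import math


def sgdr_lr(step, lr, t_0, t_mul=2.0, m_mul=1.0):
  """Pure-Python replica of the sgdr_decay arithmetic above."""
  x = step / float(t_0)
  if t_mul == 1.0:
    i_restart = math.floor(x)
    x -= i_restart
  else:
    # Largest integer i with x >= (1 - t_mul**i) / (1 - t_mul).
    i_restart = math.floor(
        math.log(1.0 - x * (1.0 - t_mul)) / math.log(t_mul))
    sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
    x = (x - sum_r) / t_mul ** i_restart
  return 0.5 * lr * (m_mul ** i_restart) * (math.cos(x * math.pi) + 1.0)


# Reproduces the table in the docstring: 0.1 at step 0, 0.05 just after the
# first restart at step 10000, 0.025 after the second restart at step 30000.
for step in (0, 5000, 9999, 10000, 20000, 30000):
  print(step, round(sgdr_lr(step, 0.1, 10000, t_mul=2, m_mul=0.5), 4))
```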
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
new file mode 100644
index 0000000000..4a46e9a49e
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -0,0 +1,145 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for sgdr learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.contrib.training.python.training.sgdr_learning_rate_decay import sgdr_decay
+from tensorflow.python.platform import googletest
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops.array_ops import placeholder
+
+
+class SGDRDecayTest(test_util.TensorFlowTestCase):
+ """Unit tests for SGDR learning rate decay."""
+
+ def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
+ """Get an array with learning rate values from the consecutive steps using
+ the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ t0 = math.pi / 2.0
+ tt = 0
+ te_next = t_e
+
+ lr_values = []
+ sh_lr = lr
+ for epoch in range(epochs):
+ for _ in range(iter_per_epoch):
+ # In the original approach, the training function is executed here
+ lr_values.append(sh_lr)
+ dt = 2.0 * math.pi / float(2.0 * t_e)
+ tt = tt + float(dt) / iter_per_epoch
+ if tt >= math.pi:
+ tt = tt - math.pi
+ cur_t = t0 + tt
+ new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr
+ sh_lr = new_lr
+ if (epoch + 1) == te_next: # time to restart
+ sh_lr = lr
+ tt = 0 # by setting to 0 we set lr to lr_max, see above
+ t_e = t_e * mult_factor # change the period of restarts
+ te_next = te_next + t_e # note the next restart's epoch
+
+ return lr_values
+
+ def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
+ """Get an array with learning rate values from the consecutive steps
+ using the current TensorFlow implementation."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
+ lr_values = []
+ for i in range(iters):
+ lr_values.append(decay.eval(feed_dict={step: i}))
+
+ return lr_values
+
+ def testCompareToOriginal(self):
+ """Compare values generated by tensorflow implementation to the values
+ generated by the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ with self.test_session():
+ lr = 10.0
+ init_steps = 2
+ t_mul = 3
+ iters = 10
+ epochs = 50
+
+ org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs)
+ sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs)
+
+ for org, sgdr in zip(org_lr, sgdr_lr):
+ self.assertAllClose(org, sgdr)
+
+ def testMDecay(self):
+ """Test m_mul argument. Check values for learning rate at the beginning
+ of the first, second, third and fourth period. """
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ lr = 0.1
+ t_e = 10
+ t_mul = 3
+ m_mul = 0.9
+
+ decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul)
+
+ test_step = t_e + t_e*t_mul
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul**2)
+
+ test_step = t_e + t_e*t_mul + t_e * (t_mul**2)
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * (m_mul**3))
+
+ def testCos(self):
+ """Check learning rate values at the beginning, in the middle
+ and at the end of the period."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+ lr = 0.2
+ t_e = 1000
+ t_mul = 1
+
+ decay = sgdr_decay(lr, step, t_e, t_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e*3//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 97880219b8..dbe87a6dbb 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -150,7 +150,6 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
-load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
exports_files(["ops/ops.pbtxt"])
@@ -334,6 +333,7 @@ filegroup(
"platform/init_main.h",
"platform/mem.h",
"platform/mutex.h",
+ "platform/numa.h",
"platform/thread_annotations.h",
],
visibility = ["//visibility:private"],
@@ -1923,7 +1923,6 @@ tf_proto_library_cc(
srcs = ["protobuf/master_service.proto"],
has_services = 1,
cc_api_version = 2,
- cc_grpc_version = 1,
cc_stubby_versions = ["2"],
protodeps = [":master_proto"],
visibility = [
@@ -1953,8 +1952,10 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob(
"**/*test*",
"lib/gif/**/*",
"lib/jpeg/**/*",
+ "lib/png/**/*",
"platform/gif.h",
"platform/jpeg.h",
+ "platform/png.h",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
],
@@ -2049,6 +2050,7 @@ cc_library(
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
+ "lib/png/**/*",
"platform/**/env_time.cc",
"platform/**/cuda_libdevice_path.cc",
"platform/**/device_tracer.cc",
@@ -2145,6 +2147,39 @@ cc_library(
)
cc_library(
+ name = "png_internal",
+ srcs = ["lib/png/png_io.cc"],
+ hdrs = [
+ "lib/bfloat16/bfloat16.h",
+ "lib/core/casts.h",
+ "lib/core/stringpiece.h",
+ "lib/png/png_io.h",
+ "platform/byte_order.h",
+ "platform/cpu_info.h",
+ "platform/default/integral_types.h",
+ "platform/default/logging.h",
+ "platform/logging.h",
+ "platform/macros.h",
+ "platform/platform.h",
+ "platform/png.h",
+ "platform/types.h",
+ ],
+ copts = tf_copts(),
+ linkopts = select({
+ "//tensorflow:freebsd": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:windows_msvc": [],
+ "//conditions:default": ["-ldl"],
+ }),
+ deps = [
+ ":lib",
+ ":lib_internal",
+ "//tensorflow/core/platform/default/build_config:png",
+ "@zlib_archive//:zlib",
+ ],
+)
+
+cc_library(
name = "tflite_portable_logging",
srcs = [],
hdrs = [
@@ -3238,6 +3273,28 @@ tf_cc_test(
)
tf_cc_test(
+ name = "platform_numa_test",
+ size = "small",
+ srcs = ["platform/numa_test.cc"],
+ tags = [
+ # This test will not pass unless it has access to all NUMA nodes
+ # on the executing machine.
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":framework",
+ ":lib",
+ ":lib_internal",
+ ":lib_test_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
name = "platform_setround_test",
size = "small",
srcs = ["platform/setround_test.cc"],
@@ -3601,6 +3658,7 @@ tf_cc_test_mkl(
deps = [
":core",
":core_cpu",
+ ":core_cpu_internal",
":framework",
":framework_internal",
":test",
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt
new file mode 100644
index 0000000000..9d464b2aea
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorFromStringHandleV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt
new file mode 100644
index 0000000000..becc729016
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
new file mode 100644
index 0000000000..180edb15a4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
@@ -0,0 +1,62 @@
+op {
+ graph_op_name: "NonMaxSuppressionWithOverlaps"
+ in_arg {
+ name: "overlaps"
+ description: <<END
+A 2-D float tensor of shape `[num_boxes, num_boxes]` representing
+the n-by-n box overlap values.
+END
+ }
+ in_arg {
+ name: "scores"
+ description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+ }
+ in_arg {
+ name: "max_output_size"
+ description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+ }
+ in_arg {
+ name: "overlap_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too much.
+END
+ }
+ in_arg {
+ name: "score_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+ }
+ out_arg {
+ name: "selected_indices"
+ description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+ }
+ summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+ description: <<END
+pruning away boxes that have high overlaps
+with previously selected boxes. Bounding boxes with score less than
+`score_threshold` are removed. N-by-n overlap values are supplied as a square
+matrix, which allows for defining a custom overlap criterion (e.g. intersection
+over union, intersection over area, etc.).
+
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes. The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather` operation. For example:
+
+ selected_indices = tf.image.non_max_suppression_with_overlaps(
+ overlaps, scores, max_output_size, overlap_threshold, score_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt
index ff2854fd2c..b5758ddbfb 100644
--- a/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt
@@ -1,5 +1,5 @@
op {
- graph_op_name: "IdentityDataset"
+ graph_op_name: "SinkDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
diff --git a/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
index ca1ee78526..1fd8baf05f 100644
--- a/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "acos"
- deprecation_message: "tf.acos is deprecated, please use tf.math.acos instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
index 7503353e41..f7946652ef 100644
--- a/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "acosh"
- deprecation_message: "tf.acosh is deprecated, please use tf.math.acosh instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Add.pbtxt b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
index cc5d68b15d..fb505a91ac 100644
--- a/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "add"
- deprecation_message: "tf.add is deprecated, please use tf.math.add instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
index 9306eaf373..ea65543a76 100644
--- a/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "as_string"
- deprecation_message: "tf.as_string is deprecated, please use tf.dtypes.as_string instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
index 7622af7b45..eedf4553c6 100644
--- a/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "asin"
- deprecation_message: "tf.asin is deprecated, please use tf.math.asin instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
index 395275c21d..10c2fb356e 100644
--- a/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "asinh"
- deprecation_message: "tf.asinh is deprecated, please use tf.math.asinh instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
index dfcd632558..03dd5dc848 100644
--- a/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "atan"
- deprecation_message: "tf.atan is deprecated, please use tf.math.atan instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
index fba79507aa..85b27bd881 100644
--- a/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "atan2"
- deprecation_message: "tf.atan2 is deprecated, please use tf.math.atan2 instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
index f7164c33e8..ee7c0600d6 100644
--- a/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "atanh"
- deprecation_message: "tf.atanh is deprecated, please use tf.math.atanh instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 56e49a2221..9552fc92e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "batch_to_space_nd"
- deprecation_message: "tf.batch_to_space_nd is deprecated, please use tf.manip.batch_to_space_nd instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
index 7c37b534c7..7ad7cbcba9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "betainc"
- deprecation_message: "tf.betainc is deprecated, please use tf.math.betainc instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
index 0c72cf2edd..f2265bad56 100644
--- a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "ceil"
- deprecation_message: "tf.ceil is deprecated, please use tf.math.ceil instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
index 7ea52d30b6..541b09a591 100644
--- a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "check_numerics"
- deprecation_message: "tf.check_numerics is deprecated, please use tf.debugging.check_numerics instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
index 568fab4037..942f4e6ed8 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "cholesky"
- deprecation_message: "tf.cholesky is deprecated, please use tf.linalg.cholesky instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
index 6550cd2d4e..1af8c0c2c9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "cos"
- deprecation_message: "tf.cos is deprecated, please use tf.math.cos instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
index ef82a45a80..2de87df40d 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "cosh"
- deprecation_message: "tf.cosh is deprecated, please use tf.math.cosh instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
index 33c1b8c617..e8a871cae6 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "cross"
- deprecation_message: "tf.cross is deprecated, please use tf.linalg.cross instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
index 55c43ceba2..8b96eee631 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "decode_base64"
- deprecation_message: "tf.decode_base64 is deprecated, please use tf.io.decode_base64 instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
index 5f6be24cc4..829608fc8f 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "decode_compressed"
- deprecation_message: "tf.decode_compressed is deprecated, please use tf.io.decode_compressed instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
index 3759047f57..9f28bc5f59 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "decode_json_example"
- deprecation_message: "tf.decode_json_example is deprecated, please use tf.io.decode_json_example instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
index a83f702dca..0010a59ca4 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "decode_raw"
- deprecation_message: "tf.decode_raw is deprecated, please use tf.io.decode_raw instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
index c9b4f76fab..5edd0c216b 100644
--- a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "dequantize"
- deprecation_message: "tf.dequantize is deprecated, please use tf.quantization.dequantize instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
index 2043facfa9..cba30e63e8 100644
--- a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "diag"
- deprecation_message: "tf.diag is deprecated, please use tf.linalg.tensor_diag instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
index 7fa30b2347..54e1f34e82 100644
--- a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "diag_part"
- deprecation_message: "tf.diag_part is deprecated, please use tf.linalg.tensor_diag_part instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
index 03f57678a8..91b4dfead7 100644
--- a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "digamma"
- deprecation_message: "tf.digamma is deprecated, please use tf.math.digamma instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
index 47b4ab4da4..71bb73cfb2 100644
--- a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "encode_base64"
- deprecation_message: "tf.encode_base64 is deprecated, please use tf.io.encode_base64 instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
index 2630962f7d..78aa1b3bc5 100644
--- a/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "equal"
- deprecation_message: "tf.equal is deprecated, please use tf.math.equal instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
index 6a511b3251..e96df0c596 100644
--- a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "erfc"
- deprecation_message: "tf.erfc is deprecated, please use tf.math.erfc instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
index e1fd718ff0..70323fe5b4 100644
--- a/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "exp"
- deprecation_message: "tf.exp is deprecated, please use tf.math.exp instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
index ca25706407..8ddf9d4d70 100644
--- a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "expm1"
- deprecation_message: "tf.expm1 is deprecated, please use tf.math.expm1 instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
index d302e26ad2..f008b1222d 100644
--- a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "extract_image_patches"
- deprecation_message: "tf.extract_image_patches is deprecated, please use tf.image.extract_image_patches instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
index 57a00a08e3..d79e936b71 100644
--- a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fft"
- deprecation_message: "tf.fft is deprecated, please use tf.spectral.fft instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
index cd14b13675..d8db83331f 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_args"
- deprecation_message: "tf.fake_quant_with_min_max_args is deprecated, please use tf.quantization.fake_quant_with_min_max_args instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
index d55cb69d1d..74f01d1a0c 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_args_gradient"
- deprecation_message: "tf.fake_quant_with_min_max_args_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_args_gradient instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
index 6ff4f2cdb2..e14fb6d118 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_vars"
- deprecation_message: "tf.fake_quant_with_min_max_vars is deprecated, please use tf.quantization.fake_quant_with_min_max_vars instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
index 817a35cc6c..4611ebdfb8 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_vars_gradient"
- deprecation_message: "tf.fake_quant_with_min_max_vars_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_gradient instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
index 275c0d5225..0936e513c3 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_vars_per_channel"
- deprecation_message: "tf.fake_quant_with_min_max_vars_per_channel is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_per_channel instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
index 897312897f..0d9968248c 100644
--- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "fake_quant_with_min_max_vars_per_channel_gradient"
- deprecation_message: "tf.fake_quant_with_min_max_vars_per_channel_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
index 788d95edc1..9b93caa0b1 100644
--- a/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "floor"
- deprecation_message: "tf.floor is deprecated, please use tf.math.floor instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
index 371dc740df..71257c8855 100644
--- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "gather_nd"
- deprecation_message: "tf.gather_nd is deprecated, please use tf.manip.gather_nd instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
index c8c56515b2..7de60d44c4 100644
--- a/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "greater"
- deprecation_message: "tf.greater is deprecated, please use tf.math.greater instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
index ccb390fb3e..9c8975c2a9 100644
--- a/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "greater_equal"
- deprecation_message: "tf.greater_equal is deprecated, please use tf.math.greater_equal instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
index 267ad8d0a0..17fbd8ace4 100644
--- a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "ifft"
- deprecation_message: "tf.ifft is deprecated, please use tf.spectral.ifft instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
index 4e7e3a6e57..8c4815c26e 100644
--- a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "igamma"
- deprecation_message: "tf.igamma is deprecated, please use tf.math.igamma instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
index ea92a0916b..b43b54391b 100644
--- a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "igammac"
- deprecation_message: "tf.igammac is deprecated, please use tf.math.igammac instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
index bce642b96a..d75fcd63e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "invert_permutation"
- deprecation_message: "tf.invert_permutation is deprecated, please use tf.math.invert_permutation instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
index a2c12f2ea0..27142644bf 100644
--- a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "is_finite"
- deprecation_message: "tf.is_finite is deprecated, please use tf.debugging.is_finite instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
index 7c29811fd7..4cd92f1cb7 100644
--- a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "is_inf"
- deprecation_message: "tf.is_inf is deprecated, please use tf.debugging.is_inf instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
index 459cf3ccbd..07d49f9436 100644
--- a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "is_nan"
- deprecation_message: "tf.is_nan is deprecated, please use tf.debugging.is_nan instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Less.pbtxt b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
index 15cbdc6d8e..055df2922a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "less"
- deprecation_message: "tf.less is deprecated, please use tf.math.less instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
index 35aa18698f..d2803ddb69 100644
--- a/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "less_equal"
- deprecation_message: "tf.less_equal is deprecated, please use tf.math.less_equal instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
index 89886b09d3..0262b838ca 100644
--- a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "lgamma"
- deprecation_message: "tf.lgamma is deprecated, please use tf.math.lgamma instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
index fb82aa7e43..26d2473b9c 100644
--- a/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "log"
- deprecation_message: "tf.log is deprecated, please use tf.math.log instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
index 6b451aa546..d85b6dccec 100644
--- a/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "log1p"
- deprecation_message: "tf.log1p is deprecated, please use tf.math.log1p instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
index 403a8c71ff..80bd98b740 100644
--- a/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "logical_and"
- deprecation_message: "tf.logical_and is deprecated, please use tf.math.logical_and instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
index f228958c77..b2244c44b1 100644
--- a/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "logical_not"
- deprecation_message: "tf.logical_not is deprecated, please use tf.math.logical_not instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
index ab89f236e7..cf78b52e07 100644
--- a/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "logical_or"
- deprecation_message: "tf.logical_or is deprecated, please use tf.math.logical_or instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
index 8930d66940..74145670a8 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matching_files"
- deprecation_message: "tf.matching_files is deprecated, please use tf.io.matching_files instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
index bad2f03f32..1122c52ab4 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_band_part"
- deprecation_message: "tf.matrix_band_part is deprecated, please use tf.linalg.band_part instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
index d241d4d721..9563bf0354 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_determinant"
- deprecation_message: "tf.matrix_determinant is deprecated, please use tf.linalg.det instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
index 208b37e297..8ab0bf75eb 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_diag"
- deprecation_message: "tf.matrix_diag is deprecated, please use tf.linalg.diag instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
index a8a50e8a89..82ce67853c 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_diag_part"
- deprecation_message: "tf.matrix_diag_part is deprecated, please use tf.linalg.diag_part instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
index 944513fcd9..85862f6eb5 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_inverse"
- deprecation_message: "tf.matrix_inverse is deprecated, please use tf.linalg.inv instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
index a6080dbc2d..6325e4f0e6 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_set_diag"
- deprecation_message: "tf.matrix_set_diag is deprecated, please use tf.linalg.set_diag instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
index caba80326b..6325dff407 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_solve"
- deprecation_message: "tf.matrix_solve is deprecated, please use tf.linalg.solve instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
index a4dfa538ed..7f865e23b2 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "matrix_triangular_solve"
- deprecation_message: "tf.matrix_triangular_solve is deprecated, please use tf.linalg.triangular_solve instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
index 90af9e145b..bcff379b71 100644
--- a/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "maximum"
- deprecation_message: "tf.maximum is deprecated, please use tf.math.maximum instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
index 33bcd6f667..9aae74226a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "minimum"
- deprecation_message: "tf.minimum is deprecated, please use tf.math.minimum instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
new file mode 100644
index 0000000000..0d358dff98
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "NonMaxSuppressionWithOverlaps"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
index 385565daaf..f37317854f 100644
--- a/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "not_equal"
- deprecation_message: "tf.not_equal is deprecated, please use tf.math.not_equal instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
index 29f02ab1ac..10b3aab0c7 100644
--- a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "parse_tensor"
- deprecation_message: "tf.parse_tensor is deprecated, please use tf.io.parse_tensor instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
index 567a448642..9df81402d5 100644
--- a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "polygamma"
- deprecation_message: "tf.polygamma is deprecated, please use tf.math.polygamma instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
index a9371b5d9b..0260eecc91 100644
--- a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "qr"
- deprecation_message: "tf.qr is deprecated, please use tf.linalg.qr instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
index 44508ef079..69404b9472 100644
--- a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "quantized_concat"
- deprecation_message: "tf.quantized_concat is deprecated, please use tf.quantization.quantized_concat instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
index 7c38fae31c..9d479be45f 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "read_file"
- deprecation_message: "tf.read_file is deprecated, please use tf.io.read_file instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
index 0f37e99f4f..c4d4c27722 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "reciprocal"
- deprecation_message: "tf.reciprocal is deprecated, please use tf.math.reciprocal instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
index 6938e20e57..b17806b338 100644
--- a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "regex_replace"
- deprecation_message: "tf.regex_replace is deprecated, please use tf.strings.regex_replace instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index 907d95a6f0..c469665b66 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "reshape"
- deprecation_message: "tf.reshape is deprecated, please use tf.manip.reshape instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index bbe9e97d60..77f595927b 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -5,10 +5,10 @@ op {
}
endpoint {
name: "reverse"
- deprecation_message: "tf.reverse is deprecated, please use tf.manip.reverse instead."
+ deprecated: true
}
endpoint {
name: "reverse_v2"
- deprecation_message: "tf.reverse_v2 is deprecated, please use tf.manip.reverse instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
index 4330a80d04..ec37a23127 100644
--- a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "rint"
- deprecation_message: "tf.rint is deprecated, please use tf.math.rint instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
index 6a45f4aff5..4fc2b81421 100644
--- a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "rsqrt"
- deprecation_message: "tf.rsqrt is deprecated, please use tf.math.rsqrt instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
index cabf171cb0..a65a19b542 100644
--- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "scatter_nd"
- deprecation_message: "tf.scatter_nd is deprecated, please use tf.manip.scatter_nd instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
index 65e34a1fcf..2e22c375c0 100644
--- a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "segment_max"
- deprecation_message: "tf.segment_max is deprecated, please use tf.math.segment_max instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
index f1e19c5571..646348072f 100644
--- a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "segment_mean"
- deprecation_message: "tf.segment_mean is deprecated, please use tf.math.segment_mean instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
index fd9a3c380d..1a77019a2d 100644
--- a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "segment_min"
- deprecation_message: "tf.segment_min is deprecated, please use tf.math.segment_min instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
index f2be8baafc..cf4d6f0237 100644
--- a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "segment_prod"
- deprecation_message: "tf.segment_prod is deprecated, please use tf.math.segment_prod instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
index c7cc1d0c9f..c6d7999455 100644
--- a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "segment_sum"
- deprecation_message: "tf.segment_sum is deprecated, please use tf.math.segment_sum instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
index 0794334987..9c19a1a177 100644
--- a/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "sin"
- deprecation_message: "tf.sin is deprecated, please use tf.math.sin instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
index c42f8678c6..155e58e6d5 100644
--- a/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "sinh"
- deprecation_message: "tf.sinh is deprecated, please use tf.math.sinh instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
index 63a7547e14..af323a6cf3 100644
--- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "space_to_batch_nd"
- deprecation_message: "tf.space_to_batch_nd is deprecated, please use tf.manip.space_to_batch_nd instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
index 01a33a3346..4bab8cf00c 100644
--- a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "squared_difference"
- deprecation_message: "tf.squared_difference is deprecated, please use tf.math.squared_difference instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
index 53c1b8053d..46a7c0361e 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_join"
- deprecation_message: "tf.string_join is deprecated, please use tf.strings.join instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
index 364806e1f5..fbcdeaad6d 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_strip"
- deprecation_message: "tf.string_strip is deprecated, please use tf.strings.strip instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
index b0e93d2b22..d122e79b39 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_to_hash_bucket"
- deprecation_message: "tf.string_to_hash_bucket is deprecated, please use tf.strings.to_hash_bucket instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
index 9576e1a9de..aef9dffefe 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_to_hash_bucket_fast"
- deprecation_message: "tf.string_to_hash_bucket_fast is deprecated, please use tf.strings.to_hash_bucket_fast instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
index e8c7c12608..385b9fd02a 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_to_hash_bucket_strong"
- deprecation_message: "tf.string_to_hash_bucket_strong is deprecated, please use tf.strings.to_hash_bucket_strong instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
index 9de1ca0b30..f740b9849d 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "string_to_number"
- deprecation_message: "tf.string_to_number is deprecated, please use tf.strings.to_number instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 25d1bb3f51..4778d7927c 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "substr"
- deprecation_message: "tf.substr is deprecated, please use tf.strings.substr instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
index 8bcf381dd4..ffa92f5580 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "tan"
- deprecation_message: "tf.tan is deprecated, please use tf.math.tan instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
index 0b9053a529..c34061c941 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "tile"
- deprecation_message: "tf.tile is deprecated, please use tf.manip.tile instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
index 1ea59d2e63..cf81843241 100644
--- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "unsorted_segment_max"
- deprecation_message: "tf.unsorted_segment_max is deprecated, please use tf.math.unsorted_segment_max instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
index 9857def6fe..475361c85a 100644
--- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "unsorted_segment_min"
- deprecation_message: "tf.unsorted_segment_min is deprecated, please use tf.math.unsorted_segment_min instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
index d9e3f7be69..a9d741bbc3 100644
--- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "unsorted_segment_prod"
- deprecation_message: "tf.unsorted_segment_prod is deprecated, please use tf.math.unsorted_segment_prod instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
index 0cffd12404..337678dcff 100644
--- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "unsorted_segment_sum"
- deprecation_message: "tf.unsorted_segment_sum is deprecated, please use tf.math.unsorted_segment_sum instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
index f28a9151ca..1a58ae19e5 100644
--- a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "write_file"
- deprecation_message: "tf.write_file is deprecated, please use tf.io.write_file instead."
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
index a84ffcdf14..4684a9d624 100644
--- a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
@@ -5,6 +5,6 @@ op {
}
endpoint {
name: "zeta"
- deprecation_message: "tf.zeta is deprecated, please use tf.math.zeta instead."
+ deprecated: true
}
}
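
Note: every api_def change above follows the same pattern: the free-form deprecation_message string on the legacy endpoint is replaced by the boolean deprecated flag, and the preferred name is carried by the op's surviving endpoint. A hypothetical api_def illustrating the resulting shape (op and endpoint names here are made up, not from this commit):

    op {
      graph_op_name: "Foo"        # hypothetical op
      endpoint {
        name: "math.foo"          # preferred, non-deprecated endpoint
      }
      endpoint {
        name: "foo"               # legacy endpoint
        deprecated: true          # replaces the old deprecation_message string
      }
    }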
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index f903faf1bd..4c670820be 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -146,18 +146,15 @@ class DirectSessionFactory : public SessionFactory {
return options.target.empty();
}
- Session* NewSession(const SessionOptions& options) override {
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
// Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
- const Status s = DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices);
- if (!s.ok()) {
- LOG(ERROR) << s;
- return nullptr;
- }
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
@@ -165,7 +162,8 @@ class DirectSessionFactory : public SessionFactory {
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
- return session;
+ *out_session = session;
+ return Status::OK();
}
Status Reset(const SessionOptions& options,
@@ -1188,12 +1186,11 @@ Status DirectSession::CreateExecutors(
delete kernel;
}
};
- params.node_outputs_cb = node_outputs_callback_;
optimizer.Optimize(lib, options_.env, device, &iter->second,
/*shape_map=*/nullptr);
- // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
+ // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
const DebugOptions& debug_options =
options.callable_options.run_options().debug_options();
if (!debug_options.debug_tensor_watch_opts().empty()) {
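
Note: the factory change above swaps the log-and-return-nullptr error path for status propagation via TF_RETURN_IF_ERROR (defined in tensorflow/core/lib/core/errors.h). A minimal sketch of the early-return pattern the macro expands to, for readers unfamiliar with it:

    // Simplified sketch; the real TF_RETURN_IF_ERROR behaves equivalently.
    #define RETURN_IF_ERROR_SKETCH(...)                      \
      do {                                                   \
        const ::tensorflow::Status _status = (__VA_ARGS__);  \
        if (!_status.ok()) return _status;                   \
      } while (0)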
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 5b424230ca..142d613129 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1714,8 +1714,7 @@ TEST(DirectSessionTest, LocalDeviceManager) {
// y = tf.square(x)
GraphDef CreateGraphForYEqualsXSquared() {
GraphDef graph_def;
- QCHECK(protobuf::TextFormat::ParseFromString(
- R"EOF(
+ const char* text_proto = R"EOF(
node {
name: "x"
op: "Placeholder"
@@ -1731,8 +1730,9 @@ node {
versions {
producer: 26
}
- )EOF",
- &graph_def));
+ )EOF";
+
+ QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
return graph_def;
}
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index f7f2cdc14f..8096139d90 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1966,17 +1966,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
device_context = device_context_map_[node->id()];
}
- // Experimental: debugger (tfdbg) access to intermediate node completion.
- if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) {
- // If the node has no output, invoke the callback with output slot set to
- // -1, signifying that this is a no-output node.
- s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr,
- false, ctx));
- }
-
for (int i = 0; i < item.num_outputs; ++i) {
const TensorValue val = ctx->release_output(i);
- if (*ctx->is_output_dead() || val.tensor == nullptr) {
+ if (val.tensor == nullptr) {
// Unless it's a Switch or a Recv, the node must produce a
// tensor value at i-th output.
if (!IsSwitch(node) && !IsRecv(node)) {
@@ -2018,13 +2010,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, to_log);
}
-
- // Experimental: debugger (tfdbg) access to intermediate node
- // outputs.
- if (impl_->params_.node_outputs_cb != nullptr) {
- s.Update(impl_->params_.node_outputs_cb(item.node->name(), i,
- out->ref, true, ctx));
- }
} else {
// NOTE that std::move is used here, so val.tensor goes to
// uninitialized state (val.tensor->IsInitialized() returns false).
@@ -2036,12 +2021,6 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, *out->val);
}
-
- // Experimental: debugger access to intermediate node outputs.
- if (impl_->params_.node_outputs_cb != nullptr) {
- s.Update(impl_->params_.node_outputs_cb(
- item.node->name(), i, out->val.get(), false, ctx));
- }
}
} else {
s.Update(errors::Internal("Output ", i, " of type ",
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index e5d7b7c53c..cd01b43aea 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -103,7 +103,6 @@ class Executor {
const Tensor* tensor, const bool is_ref,
OpKernelContext* ctx)>
NodeOutputsCallback;
- NodeOutputsCallback node_outputs_cb = nullptr;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
@@ -139,8 +138,6 @@ struct LocalExecutorParams {
// when the executor is deleted.
std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
std::function<void(OpKernel*)> delete_kernel;
-
- Executor::Args::NodeOutputsCallback node_outputs_cb;
};
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
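
Note: with node_outputs_cb removed from both the executor and the session, the supported way to observe intermediate tensors is the tfdbg path referenced in direct_session.cc: debug ops are grafted into the graph from RunOptions. A minimal sketch, using only the calls that also appear in the deleted debug_gateway_test.cc below (the watched node name is illustrative):

    RunOptions run_opts;
    DebugTensorWatch* watch =
        run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
    watch->set_node_name("y");              // node to watch (illustrative)
    watch->set_output_slot(0);
    watch->add_debug_ops("DebugIdentity");  // or "DebugNanCount", etc.
    // session->Run(run_opts, inputs, output_names, target_nodes, &outputs,
    //              &run_metadata);  // debug nodes are inserted before execution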
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 86851c2c07..1f0773d387 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -628,6 +628,40 @@ class ColocationGraph {
return parent;
}
+ // Ensures that the devices of 'dst's resource and reference inputs match
+ // the device specified for 'src', which is an input of 'dst' with a
+ // partially or fully specified device.
+ Status VerifyResourceAndRefInputsCanBeColocated(
+ const Node* dst, const Node* src,
+ const DeviceNameUtils::ParsedName& src_parsed_name) {
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(dst->input_edges(&edges));
+ for (const Edge* edge : edges) {
+ DataType input_type = dst->input_type(edge->dst_input());
+ if (input_type == DT_RESOURCE || IsRefType(input_type)) {
+ const Node* input_node = edge->src();
+ if (input_node == src) {
+ continue;
+ }
+ const auto& input_root = members_[FindRoot(input_node->id())];
+ const auto& input_parsed_name = input_root.device_name;
+ if (DeviceNameUtils::HasSomeDetails(input_parsed_name) &&
+ !DeviceNameUtils::AreCompatibleDevNames(input_parsed_name,
+ src_parsed_name)) {
+ return AttachDef(
+ errors::InvalidArgument(
+ "Could not colocate node with its "
+ "resource and reference inputs; devices ",
+ DeviceNameUtils::ParsedNameToString(input_parsed_name),
+ " and ", DeviceNameUtils::ParsedNameToString(src_parsed_name),
+ " are not compatible."),
+ *dst);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
Graph* const graph_; // Not owned.
std::vector<Member> members_;
const DeviceSet* device_set_; // Not owned.
@@ -646,6 +680,15 @@ bool IsGeneratorNode(const Node* node) {
!IsRefType(node->output_type(0));
}
+bool IsExemptFromResourceInputColocation(const Node* node) {
+ // Note: Partitioned function calls, which place and partition their
+ // function bodies, are exempt from this check: they forward resource and
+ // ref inputs to operations that are appropriately placed, instead of
+ // dereferencing them.
+ const string& op_type = node->op_def().name();
+ return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
+}
+
} // namespace
Placer::Placer(Graph* graph, const DeviceSet* devices,
@@ -680,8 +723,8 @@ Status Placer::Run() {
// 2. Enumerate the constraint edges, and use them to update the disjoint
// node set.
- // If `node` has an input edge with reference type, add an
- // edge from the source of that edge to `node`.
+ // If `node` has an input edge with reference type, add an edge from the
+ // source of that edge to `node`.
for (const Edge* edge : graph_->edges()) {
if (edge->IsControlEdge()) {
continue;
@@ -689,7 +732,10 @@ Status Placer::Run() {
Node* src = edge->src();
Node* dst = edge->dst();
DataType input_type = dst->input_type(edge->dst_input());
- if (input_type == DT_RESOURCE || IsRefType(input_type)) {
+ if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
+ !IsExemptFromResourceInputColocation(dst)) {
+ // Colocate `src` and `dst` to maintain the invariant that nodes connected
+ // by reference edges are colocated.
int src_root_id = colocation_graph.FindRoot(src->id());
int dst_root_id = colocation_graph.FindRoot(dst->id());
auto& src_root = colocation_graph.members_[src_root_id];
@@ -706,6 +752,9 @@ Status Placer::Run() {
// incompatible.
if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
dest_parsed_name)) {
+ TF_RETURN_IF_ERROR(
+ colocation_graph.VerifyResourceAndRefInputsCanBeColocated(
+ dst, src, source_parsed_name));
if (log_device_placement_) {
LOG(INFO) << "Ignoring device specification "
<< DeviceNameUtils::ParsedNameToString(dest_parsed_name)
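
Note: previously an incompatible device specification on a resource/ref edge was simply ignored (with optional logging). The new VerifyResourceAndRefInputsCanBeColocated call turns the genuinely unsatisfiable case, two resource or reference inputs pinned to incompatible devices, into an InvalidArgument error before soft placement can paper over it. A minimal sketch of the underlying compatibility test, assuming the DeviceNameUtils calls used above:

    DeviceNameUtils::ParsedName cpu, gpu;
    // ParseFullName returns false on malformed names; ignored here for brevity.
    DeviceNameUtils::ParseFullName("/job:a/replica:0/task:0/device:CPU:0", &cpu);
    DeviceNameUtils::ParseFullName("/job:a/replica:0/task:0/device:GPU:0", &gpu);
    // Differing device types are incompatible, so a node reading resources on
    // both devices now fails placement with InvalidArgument.
    CHECK(!DeviceNameUtils::AreCompatibleDevNames(cpu, gpu));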
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 5ad251c892..07a7724f16 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -575,6 +575,10 @@ REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp);
REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp);
+REGISTER_OP("TestTwoHandlesIn").Input("i: resource").Input("j: resource");
+REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeGPU"), DummyOp);
+
// Tests all combinations of resource handles and ops using them.
TEST_F(PlacerTest, TestResourceHandle) {
auto handle_test = [this](const string& var_op_name,
@@ -609,6 +613,42 @@ TEST_F(PlacerTest, TestResourceHandle) {
handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok());
}
+TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
+ auto handle_test = [this](bool allow_soft_placement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* var_cpu =
+ ops::SourceOp("TestHandleVariable", b.opts().WithName("var_cpu"));
+ Node* var_gpu =
+ ops::SourceOp("TestHandleVariable", b.opts().WithName("var_gpu"));
+ ops::BinaryOp("TestTwoHandlesIn", var_cpu, var_gpu,
+ b.opts().WithName("two_handles_in"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+
+ GetNodeByName(g, "var_cpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakecpu:0");
+ GetNodeByName(g, "var_gpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakegpu:0");
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(allow_soft_placement);
+ options.config.set_log_device_placement(true);
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Could not colocate node with its resource and reference inputs"));
+ return Status::OK();
+ };
+
+ TF_EXPECT_OK(handle_test(false));
+ TF_EXPECT_OK(handle_test(true));
+}
+
// Test that an assignment of an operator to the wrong device
// is ignored when it could never be satisfied (due to reference
// edges, for example).
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index 4a9248171b..8c30beeec2 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -53,27 +53,33 @@ Status Session::PRun(const string& handle,
Session* NewSession(const SessionOptions& options) {
SessionFactory* factory;
- const Status s = SessionFactory::GetFactory(options, &factory);
+ Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
LOG(ERROR) << s;
return nullptr;
}
- return factory->NewSession(options);
+ Session* out_session;
+ s = NewSession(options, &out_session);
+ if (!s.ok()) {
+ LOG(ERROR) << "Failed to create session: " << s;
+ return nullptr;
+ }
+ return out_session;
}
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
- const Status s = SessionFactory::GetFactory(options, &factory);
+ Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) << s;
return s;
}
- *out_session = factory->NewSession(options);
- if (!*out_session) {
- return errors::Internal("Failed to create session.");
+ s = factory->NewSession(options, out_session);
+ if (!s.ok()) {
+ *out_session = nullptr;
}
- return Status::OK();
+ return s;
}
Status Reset(const SessionOptions& options,
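
Note: after this change the two NewSession overloads differ only in how they report failure; both delegate to the factory. A usage sketch:

    SessionOptions options;
    // Preferred: the factory's error is returned to the caller.
    Session* session = nullptr;
    Status s = NewSession(options, &session);
    if (!s.ok()) {
      LOG(ERROR) << s;  // handle the error
    }
    // Legacy: returns nullptr on failure; the error is only logged.
    std::unique_ptr<Session> legacy(NewSession(options));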
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
index df3198a70d..81c172c6ae 100644
--- a/tensorflow/core/common_runtime/session_factory.h
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -30,7 +30,12 @@ struct SessionOptions;
class SessionFactory {
public:
- virtual Session* NewSession(const SessionOptions& options) = 0;
+ // Creates a new session and stores it in *out_session, or fails with an error
+ // status if the Session could not be created. Caller takes ownership of
+ // *out_session if this returns Status::OK().
+ virtual Status NewSession(const SessionOptions& options,
+ Session** out_session) = 0;
+
virtual bool AcceptsOptions(const SessionOptions& options) = 0;
// Abort and close all existing sessions, disconnecting their resources from
diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc
index feaf29c7bb..1fa5aad60c 100644
--- a/tensorflow/core/common_runtime/session_test.cc
+++ b/tensorflow/core/common_runtime/session_test.cc
@@ -47,8 +47,10 @@ class FakeSessionFactory : public SessionFactory {
return str_util::StartsWith(options.target, "fake");
}
- Session* NewSession(const SessionOptions& options) override {
- return nullptr;
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
+ *out_session = nullptr;
+ return Status::OK();
}
};
class FakeSessionRegistrar {
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 36e9b3455a..591c22b8f6 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -82,25 +82,6 @@ cc_library(
)
tf_cuda_library(
- name = "debug_gateway_internal",
- srcs = ["debug_gateway.cc"],
- hdrs = ["debug_gateway.h"],
- copts = tf_copts(),
- linkstatic = 1,
- deps = [
- ":debug",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:direct_session_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:proto_text",
- "//tensorflow/core:protos_all_cc",
- ],
- alwayslink = 1,
-)
-
-tf_cuda_library(
name = "debugger_state_impl",
srcs = ["debugger_state_impl.cc"],
hdrs = ["debugger_state_impl.h"],
@@ -187,42 +168,6 @@ tf_cuda_library(
],
)
-# TODO(cais): Fix flakiness on GPU and change this back to a tf_cc_test_gpu.
-# See b/34081273.
-tf_cc_test(
- name = "debug_gateway_test",
- size = "small",
- srcs = ["debug_gateway_test.cc"],
- args = ["--heap_check=local"],
- linkstatic = tf_kernel_tests_linkstatic(),
- tags = [
- "no_cuda_on_cpu_tap",
- "no_gpu",
- ],
- deps = [
- ":debug",
- ":debug_gateway_internal",
- ":debug_graph_utils",
- "//tensorflow/cc:cc_ops",
- "//tensorflow/core:all_kernels",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:direct_session",
- "//tensorflow/core:direct_session_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:gpu_runtime",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- "//tensorflow/core/kernels:debug_ops",
- "//tensorflow/core/kernels:ops_util",
- ],
-)
-
tf_cc_test(
name = "debug_io_utils_test",
size = "small",
diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc
deleted file mode 100644
index 2e1aabd1cc..0000000000
--- a/tensorflow/core/debug/debug_gateway.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/debug/debug_gateway.h"
-
-#include <utility>
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/session_factory.h"
-#include "tensorflow/core/framework/tensor.h"
-
-namespace tensorflow {
-
-DebugGateway::DebugGateway(DirectSession* session) : session_(session) {
- session_->node_outputs_callback_ =
- [this](const string& node_name, const int output_slot,
- const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) {
- if (comp_cb_ != nullptr && output_slot <= 0) {
- // The node completion callback is invoked once for a node regardless
- // of whether the node has zero, one or more outputs.
- // The output_slot can be negative (-1, or kControlSlot) if
- // node_outputs_callback_ is invoked for a node with no output. If
- // that is the case, notify the callback that the node in question has
- // no output.
- comp_cb_(node_name, output_slot == 0);
- }
-
- // Copy tensor values (e.g., from GPU to host) only if the
- // value callback is not nullptr.
- if (val_cb_ != nullptr && output_slot >= 0) {
- CopyTensor(node_name, output_slot, tensor, ctx,
- [this, node_name, output_slot,
- is_ref](const Tensor* copied_tensor) {
- val_cb_(node_name, output_slot, *copied_tensor, is_ref);
- });
- }
-
- return Status::OK();
- };
-}
-
-DebugGateway::~DebugGateway() {
- if (session_ != nullptr) {
- session_->node_outputs_callback_ = nullptr;
- }
-}
-
-void DebugGateway::SetNodeCompletionCallback(NodeCompletionCallback callback) {
- comp_cb_ = std::move(callback);
-}
-
-void DebugGateway::SetNodeValueCallback(NodeValueCallback callback) {
- val_cb_ = std::move(callback);
-}
-
-void DebugGateway::CopyTensor(const string& node_name, const int output_slot,
- const Tensor* src_tensor, OpKernelContext* ctx,
- CopyDoneCallback copy_done_cb) {
- Device* device = static_cast<Device*>(ctx->device());
-
- // Determine if the tensor is initialized properly.
- // The second part of the check is necessary because in some cases, a
- // tensor can pass the IsInitialized() check, but the dtype is not set,
- // e.g., tf.FIFOQueue.
- if (src_tensor->IsInitialized() && DataTypeSize(src_tensor->dtype()) > 0) {
- // Tensor is initialized.
-
- string tensor_tag = strings::StrCat(node_name, ":", output_slot);
-
- // Create copied tensor on host
- Allocator* cpu_allocator = tensorflow::cpu_allocator();
- Tensor cpu_tensor(cpu_allocator, src_tensor->dtype(), src_tensor->shape());
-
- // Determine if the tensor is on device (GPU) or host (CPU).
- // The second part of the check is necessary because even an OpKernel on
- // GPU may have output tensors allocated on CPU.
- if ((device->name().find("GPU:") != string::npos ||
- device->name().find("SYCL:") != string::npos) &&
- !ctx->output_alloc_attr(output_slot).on_host()) {
- // GPU tensors: Copy it to host (CPU).
- DeviceContext* device_ctxt = ctx->op_device_context();
-
- // Copy device (e.g., GPU) tensor to host and when done, invoke the
- // callback.
- device_ctxt->CopyDeviceTensorToCPU(
- src_tensor, "TensorCopy", device, &cpu_tensor,
- [node_name, cpu_tensor, copy_done_cb](const Status& s) {
- if (s.ok()) {
- copy_done_cb(&cpu_tensor);
- } else {
- LOG(ERROR) << "Copying of device Tensor " << node_name
- << " to CPU for debugging failed.";
- }
- });
- } else {
- // For CPU tensors, copy the source tensor and own the copy, because the
- // value callback may outlive the lifetime of the tensor, and the tensor
- // may share the underlying buffer with other tensors.
- cpu_tensor.UnsafeCopyFromInternal(*src_tensor, src_tensor->dtype(),
- src_tensor->shape());
-
- copy_done_cb(&cpu_tensor);
- }
- } else {
- // Tensor is not initialized: No need to copy.
- copy_done_cb(src_tensor);
- }
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_gateway.h b/tensorflow/core/debug/debug_gateway.h
deleted file mode 100644
index bf5b6e08db..0000000000
--- a/tensorflow/core/debug/debug_gateway.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_DEBUG_DEBUG_SESSION_H_
-#define TENSORFLOW_DEBUG_DEBUG_SESSION_H_
-
-#include <unordered_map>
-
-#include "tensorflow/core/common_runtime/direct_session.h"
-#include "tensorflow/core/common_runtime/executor.h"
-
-namespace tensorflow {
-
- // Experimental. tfdbg (TensorFlow Debugger): Gateway to intermediate node
-// outputs during Session Run calls. Currently limited to DirectSession.
-class DebugGateway {
- public:
- DebugGateway(DirectSession* session);
- virtual ~DebugGateway();
-
- // Callback for node completion. This callback is invoked only once for
- // a node regardless of whether it has one or more outputs. The value(s) of
- // the output tensor(s) are not necessarily available when this callback is
- // invoked. They may need to be asynchronously copied from device (e.g.,
- // GPU) to host, hence the need for the NodeValueCallback below.
- //
- // Args:
- // node_name: Name of the node that has just completed execution
- // any_output: Whether the node has any output(s)
- typedef std::function<void(const string& node_name, const bool any_output)>
- NodeCompletionCallback;
- void SetNodeCompletionCallback(NodeCompletionCallback callback);
-
- // Callback for node value. This is invoked when the value of a node's
- // output tensor is available on the host, possibly after copying from
- // a device (e.g., GPU).
- //
- // Args:
- // node_name: Name of the node of which the output has become available
- // output_slot: Output slot number of the output Tensor
- // tensor_value: Reference to the tensor value
- //   is_ref: Whether the output is of the reference type
- typedef std::function<void(const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref)>
- NodeValueCallback;
- void SetNodeValueCallback(NodeValueCallback callback);
-
- // TODO(cais): Add whitelists for ops/tensors (e.g., {"A:0", "B:0"})
- // for node completion callback (whitelist_comp_) and node value callback
- // (whitelist_val_). If whitelist_comp_ is non-empty, the gateway will
- // invoke the NodeCompletionCallback only for the nodes specified in the
- // whitelist. And so forth for whitelist_val_.
-
- private:
- DirectSession* session_;
- // TODO(cais): DebugGateway currently supports only DirectSession. Add
- // support for GrpcSession.
-
- NodeCompletionCallback comp_cb_ = nullptr;
- NodeValueCallback val_cb_ = nullptr;
-
- typedef std::function<void(const Tensor* dst_tensor)> CopyDoneCallback;
-
- void CopyTensor(const string& node_name, const int output_slot,
- const Tensor* src_tensor, OpKernelContext* ctx,
- CopyDoneCallback copy_done_cb);
-};
-
-} // end namespace tensorflow
-
-#endif // TENSORFLOW_DEBUG_DEBUG_SESSION_H_
diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc
deleted file mode 100644
index b1bbd3f698..0000000000
--- a/tensorflow/core/debug/debug_gateway_test.cc
+++ /dev/null
@@ -1,1011 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/debug/debug_gateway.h"
-
-#include <algorithm>
-#include <cstdlib>
-#include <memory>
-#include <unordered_map>
-
-#include "tensorflow/core/debug/debug_graph_utils.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/graph/testlib.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/protobuf/rewriter_config.pb.h"
-
-namespace tensorflow {
-namespace {
-
-std::unique_ptr<DirectSession> CreateSession() {
- SessionOptions options;
- // Turn off graph optimizer so we can observe intermediate node states.
- options.config.mutable_graph_options()
- ->mutable_optimizer_options()
- ->set_opt_level(OptimizerOptions_Level_L0);
- options.config.mutable_graph_options()
- ->mutable_rewrite_options()
- ->set_constant_folding(RewriterConfig::OFF);
- options.config.mutable_graph_options()
- ->mutable_rewrite_options()
- ->set_dependency_optimization(RewriterConfig::OFF);
-
- return std::unique_ptr<DirectSession>(
- dynamic_cast<DirectSession*>(NewSession(options)));
-}
-
-class SessionDebugMinusAXTest : public ::testing::Test {
- public:
- void Initialize(std::initializer_list<float> a_values) {
- Graph graph(OpRegistry::Global());
-
-#if GOOGLE_CUDA
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
-#elif defined(TENSORFLOW_USE_SYCL)
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
-#else
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0";
-#endif
-
- Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
- test::FillValues<float>(&a_tensor, a_values);
- Node* a = test::graph::Constant(&graph, a_tensor);
- a->set_assigned_device_name(kDeviceName);
- a_ = a->name();
-
- Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
- test::FillValues<float>(&x_tensor, {1, 1});
- Node* x = test::graph::Constant(&graph, x_tensor);
- x->set_assigned_device_name(kDeviceName);
- x_ = x->name();
-
- // y = A * x
- Node* y = test::graph::Matmul(&graph, a, x, false, false);
- y->set_assigned_device_name(kDeviceName);
- y_ = y->name();
-
- Node* y_neg = test::graph::Unary(&graph, "Neg", y);
- y_neg_ = y_neg->name();
- y_neg->set_assigned_device_name(kDeviceName);
-
- test::graph::ToGraphDef(&graph, &def_);
- }
-
- string a_;
- string x_;
- string y_;
- string y_neg_;
- GraphDef def_;
-};
-
-TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) {
- Initialize({3, 2, -1, 0});
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- // Supply completion and value callbacks
- mutex mu;
- // Completed nodes with and without outputs
- std::vector<string> completed_nodes_w_outputs;
- std::vector<string> completed_nodes_wo_outputs;
-
- Notification callbacks_done;
- debug_gateway.SetNodeCompletionCallback(
- [&mu, &completed_nodes_w_outputs, &completed_nodes_wo_outputs](
- const string& node_name, const bool any_output) {
- mutex_lock l(mu);
- if (any_output) {
- completed_nodes_w_outputs.push_back(node_name);
- } else {
- completed_nodes_wo_outputs.push_back(node_name);
- }
- });
-
- std::vector<bool> tensors_initialized;
- std::unordered_map<string, Tensor> tensor_vals;
- // output_slot values recorded in value callbacks
- std::vector<int> output_slots_val;
- // is_ref values recorded in value callbacks
- std::vector<bool> is_refs_val;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &tensors_initialized, &tensor_vals, &output_slots_val,
- &is_refs_val,
- &callbacks_done](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
- tensors_initialized.push_back(tensor_value.IsInitialized());
- tensor_vals.insert(std::make_pair(node_name, tensor_value));
- output_slots_val.push_back(output_slot);
- is_refs_val.push_back(is_ref);
-
- // Set the notification once we have the value from the target node.
- if (node_name == y_neg_ && !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
- TF_ASSERT_OK(session->Create(def_));
-
- std::vector<std::pair<string, Tensor>> inputs;
-
- // Request two targets: one fetch output and one non-fetched output.
- std::vector<string> output_names = {y_ + ":0"};
- std::vector<string> target_nodes = {y_neg_};
- std::vector<Tensor> outputs;
- Status s = session->Run(inputs, output_names, target_nodes, &outputs);
- TF_ASSERT_OK(s);
-
- // Wait for callbacks to complete.
- callbacks_done.WaitForNotification();
-
- ASSERT_EQ(1, outputs.size());
- // The first output should be initialized and have the correct
- // output.
- auto mat = outputs[0].matrix<float>();
- ASSERT_TRUE(outputs[0].IsInitialized());
- EXPECT_FLOAT_EQ(5.0, mat(0, 0));
-
- // Verify the calling history of the completion callback
- // The following verifies each node with output(s) invoked the callback
- // exactly once.
- ASSERT_GE(completed_nodes_w_outputs.size(), 4); // There may be added nodes.
-
- ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(),
- completed_nodes_w_outputs.end(), a_));
- ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(),
- completed_nodes_w_outputs.end(), x_));
- ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(),
- completed_nodes_w_outputs.end(), y_));
- ASSERT_EQ(1, std::count(completed_nodes_w_outputs.begin(),
- completed_nodes_w_outputs.end(), y_neg_));
-
- // Apart from nodes with outputs, there are also no-output (control) nodes.
- // They ought to be captured by the DebugGateway through
- // NodeOutputsCallback as well.
- ASSERT_GT(completed_nodes_wo_outputs.size(), 0);
-
- // The DebugGateway should have captured the _SOURCE node.
- ASSERT_LE(1, std::count(completed_nodes_wo_outputs.begin(),
- completed_nodes_wo_outputs.end(), "_SOURCE"));
-
- // Verify the calling history of the value callback
- ASSERT_EQ(completed_nodes_w_outputs.size(), tensors_initialized.size());
-
- // In this graph, there is no uninitialized node value.
- ASSERT_EQ(
- tensors_initialized.end(),
- std::find(tensors_initialized.begin(), tensors_initialized.end(), false));
-
- ASSERT_EQ(completed_nodes_w_outputs.size(), tensor_vals.size());
- ASSERT_EQ(completed_nodes_w_outputs.size(), output_slots_val.size());
- ASSERT_EQ(completed_nodes_w_outputs.size(), is_refs_val.size());
-
- // Verify the intermediate tensor values captured through the value callback
- auto mat_a = tensor_vals[a_].matrix<float>();
- ASSERT_EQ(3.0, mat_a(0, 0));
- ASSERT_EQ(2.0, mat_a(0, 1));
- ASSERT_EQ(-1.0, mat_a(1, 0));
- ASSERT_EQ(0.0, mat_a(1, 1));
-
- auto mat_x = tensor_vals[x_].matrix<float>();
- ASSERT_EQ(1.0, mat_x(0, 0));
- ASSERT_EQ(1.0, mat_x(1, 0));
-
- auto mat_y = tensor_vals[y_].matrix<float>();
- ASSERT_EQ(5.0, mat_y(0, 0));
- ASSERT_EQ(-1.0, mat_y(1, 0));
-
- auto mat_y_neg = tensor_vals[y_neg_].matrix<float>();
- ASSERT_EQ(-5.0, mat_y_neg(0, 0));
- ASSERT_EQ(1.0, mat_y_neg(1, 0));
-
- // In this graph, all outputs are on the first slot
- ASSERT_EQ(output_slots_val.size(),
- std::count_if(output_slots_val.begin(), output_slots_val.end(),
- [](int slot) { return slot == 0; }));
-
- // In this graph, there is no ref-type tensor.
- ASSERT_EQ(is_refs_val.end(),
- std::find(is_refs_val.begin(), is_refs_val.end(), true));
-}
-
-TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
- // Tensor contains one count of NaN
- Initialize({3, std::numeric_limits<float>::quiet_NaN(), -1, 0});
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- // Create debug tensor watch options with two debug ops:
- // DebugIdentity and DebugNanCount
- RunOptions run_opts;
- run_opts.set_output_partition_graphs(true);
-
- const string debug_identity = "DebugIdentity";
- const string debug_nan_count = "DebugNanCount";
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_node_name(y_);
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops(debug_identity);
- tensor_watch_opts->add_debug_ops(debug_nan_count);
-
- // Expected name of the inserted debug node
- string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(y_, ":", 0), 0, debug_identity);
- string debug_nan_count_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(y_, ":", 0), 1, debug_nan_count);
-
- // Supply completion and value callbacks
- mutex mu;
- // Completed nodes with and without outputs
- std::vector<string> completed_debug_nodes;
-
- Notification callbacks_done;
- debug_gateway.SetNodeCompletionCallback(
- [&mu, &debug_identity_node_name, &debug_nan_count_node_name,
- &completed_debug_nodes](const string& node_name, const bool any_output) {
- mutex_lock l(mu);
- if (any_output && (node_name == debug_identity_node_name ||
- node_name == debug_nan_count_node_name)) {
- completed_debug_nodes.push_back(node_name);
- }
- });
-
- std::vector<Tensor> watched_tensor_vals;
- std::vector<Tensor> debug_identity_tensor_vals;
- std::vector<Tensor> debug_nan_count_tensor_vals;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &debug_identity_node_name, &debug_nan_count_node_name,
- &watched_tensor_vals, &debug_identity_tensor_vals,
- &debug_nan_count_tensor_vals,
- &callbacks_done](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
- if (node_name == y_) {
- watched_tensor_vals.push_back(tensor_value);
- } else if (node_name == debug_identity_node_name && output_slot == 0) {
- // output_slot == 0 carries the debug signal. Same below.
- debug_identity_tensor_vals.push_back(tensor_value);
- } else if (node_name == debug_nan_count_node_name && output_slot == 0) {
- debug_nan_count_tensor_vals.push_back(tensor_value);
- }
-
- // Set the notification once we have the value from the target node.
- if (node_name == y_neg_ && !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
- TF_ASSERT_OK(session->Create(def_));
-
- std::vector<std::pair<string, Tensor>> inputs;
-
- // Request two targets: one fetch output and one non-fetched output.
- std::vector<string> output_names = {y_ + ":0"};
- std::vector<string> target_nodes = {y_neg_};
- std::vector<Tensor> outputs;
-
- RunMetadata run_metadata;
- Status s = session->Run(run_opts, inputs, output_names, target_nodes,
- &outputs, &run_metadata);
- TF_ASSERT_OK(s);
-
-// Verify the correct number of partition graphs (GraphDefs) outputted
-// through RunMetadata, given whether GPU is involved.
-#if GOOGLE_CUDA
- ASSERT_EQ(2, run_metadata.partition_graphs().size());
-#elif defined(TENSORFLOW_USE_SYCL)
- ASSERT_EQ(2, run_metadata.partition_graphs().size());
-#else
- ASSERT_EQ(1, run_metadata.partition_graphs().size());
-#endif
-
- // Wait for callbacks to complete.
- callbacks_done.WaitForNotification();
-
- // Verify that each of the two debug nodes has completed exactly once.
- ASSERT_EQ(2, completed_debug_nodes.size());
- ASSERT_EQ(
- 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(),
- debug_identity_node_name));
- ASSERT_EQ(
- 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(),
- debug_nan_count_node_name));
-
- // Verify that the tensor values from the watched node and the identity
- // debug node are received and they are equal (owing to the debug op being
- // "DebugIdentity")
- ASSERT_EQ(1, watched_tensor_vals.size());
- ASSERT_EQ(1, debug_identity_tensor_vals.size());
- auto mat_y = watched_tensor_vals[0].matrix<float>();
- auto mat_identity = debug_identity_tensor_vals[0].matrix<float>();
- // ASSERT_EQ doesn't work for nan == nan
- ASSERT_TRUE(std::isnan(mat_y(0, 0)));
- ASSERT_TRUE(std::isnan(mat_identity(0, 0)));
- ASSERT_EQ(-1, mat_identity(1, 0));
-
- // Verify that the output from the NaN-count debug node indicates exactly
- // one NaN.
- ASSERT_EQ(1, debug_nan_count_tensor_vals.size());
- ASSERT_EQ(1, debug_nan_count_tensor_vals[0].scalar<int64>()());
-}
-
-#if !defined(GOOGLE_CUDA) && !defined(TENSORFLOW_USE_SYCL)
-// TODO(cais): Reinstate the following test for concurrent debugged runs on
-// a GPU once the root cause of the ~0.5% flakiness has been addressed.
-// (b/34081273)
-TEST_F(SessionDebugMinusAXTest,
- RunSimpleNetworkConcurrentlyWithDifferentDebugTensorWatches) {
- // Test concurrent Run() calls on a graph with different debug watches.
-
- Initialize({3, 2, -1, 0});
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
- TF_ASSERT_OK(session->Create(def_));
-
- // Number of concurrent Run() calls to launch.
- const int kConcurrentRuns = 3;
- thread::ThreadPool* tp =
- new thread::ThreadPool(Env::Default(), "test", kConcurrentRuns);
-
- std::vector<string> output_names = {y_ + ":0"};
- std::vector<string> target_nodes = {y_neg_};
-
- mutex mu;
- DebugGateway debug_gateway(session.get());
- std::unordered_map<string, Tensor> debug_identity_tensor_vals;
-
- const string debug_identity = "DebugIdentity";
-
- const string a_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(a_, ":", 0), 0, debug_identity);
- const string x_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(x_, ":", 0), 0, debug_identity);
- const string y_debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(y_, ":", 0), 0, debug_identity);
-
- Notification callbacks_done;
- volatile int val_callback_count = 0;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &val_callback_count, &a_debug_identity_node_name,
- &x_debug_identity_node_name, &y_debug_identity_node_name,
- &debug_identity_tensor_vals, &callbacks_done,
- &kConcurrentRuns](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
-
- if (node_name == a_debug_identity_node_name && output_slot == 0) {
- debug_identity_tensor_vals["a"] = tensor_value;
- val_callback_count++;
- } else if (node_name == x_debug_identity_node_name &&
- output_slot == 0) {
- // output_slot == 0 carries the debug signal.
- debug_identity_tensor_vals["x"] = tensor_value;
- val_callback_count++;
- } else if (node_name == y_debug_identity_node_name &&
- output_slot == 0) {
- debug_identity_tensor_vals["y"] = tensor_value;
- val_callback_count++;
- }
-
-          // Set the notification once we have received values from the
-          // callbacks of all the concurrent Run() calls.
- if (val_callback_count == kConcurrentRuns &&
- !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
- int run_counter = 0;
- mutex run_lock;
-
- // Function to be executed concurrently.
- auto fn = [this, &run_lock, &run_counter, &session, output_names,
- target_nodes, &debug_identity]() {
- // Create unique debug tensor watch options for each of the concurrent
- // run calls.
- RunOptions run_opts;
- run_opts.set_output_partition_graphs(true);
-
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops(debug_identity);
-
- {
- // Let the concurrent runs watch different tensors.
-
- mutex_lock l(run_lock);
-
- if (run_counter == 0) {
- // Let the 1st concurrent run watch a.
- tensor_watch_opts->set_node_name(a_);
- } else if (run_counter == 1) {
-        // Let the 2nd concurrent run watch x.
- tensor_watch_opts->set_node_name(x_);
- } else if (run_counter == 2) {
-        // Let the 3rd concurrent run watch y.
- tensor_watch_opts->set_node_name(y_);
- }
-
- run_counter++;
- }
-
- // Run the graph.
- RunMetadata run_metadata;
- std::vector<std::pair<string, Tensor>> inputs;
- std::vector<Tensor> outputs;
- Status s = session->Run(run_opts, inputs, output_names, target_nodes,
- &outputs, &run_metadata);
- TF_ASSERT_OK(s);
-
- ASSERT_EQ(1, run_metadata.partition_graphs().size());
-
- ASSERT_EQ(1, outputs.size());
- ASSERT_TRUE(outputs[0].IsInitialized());
- ASSERT_EQ(TensorShape({2, 1}), outputs[0].shape());
- auto mat = outputs[0].matrix<float>();
- EXPECT_FLOAT_EQ(5.0, mat(0, 0));
- EXPECT_FLOAT_EQ(-1.0, mat(1, 0));
- };
-
- for (int i = 0; i < kConcurrentRuns; ++i) {
- tp->Schedule(fn);
- }
-
- // Wait for the debug callbacks to finish.
- callbacks_done.WaitForNotification();
-
- // Wait for the concurrent functions with Run() calls to finish.
- delete tp;
-
- {
- mutex_lock l(mu);
-
- ASSERT_EQ(kConcurrentRuns, val_callback_count);
- ASSERT_EQ(kConcurrentRuns, debug_identity_tensor_vals.size());
-
- ASSERT_EQ(TensorShape({2, 2}), debug_identity_tensor_vals["a"].shape());
- auto a_mat_identity = debug_identity_tensor_vals["a"].matrix<float>();
- ASSERT_EQ(3.0, a_mat_identity(0, 0));
- ASSERT_EQ(2.0, a_mat_identity(0, 1));
- ASSERT_EQ(-1.0, a_mat_identity(1, 0));
- ASSERT_EQ(0.0, a_mat_identity(1, 1));
-
- ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["x"].shape());
- auto x_mat_identity = debug_identity_tensor_vals["x"].matrix<float>();
- ASSERT_EQ(1.0, x_mat_identity(0, 0));
- ASSERT_EQ(1.0, x_mat_identity(1, 0));
-
- ASSERT_EQ(TensorShape({2, 1}), debug_identity_tensor_vals["y"].shape());
- auto y_mat_identity = debug_identity_tensor_vals["y"].matrix<float>();
- ASSERT_EQ(5.0, y_mat_identity(0, 0));
- ASSERT_EQ(-1.0, y_mat_identity(1, 0));
- }
-}
-#endif
-
-class SessionDebugOutputSlotWithoutOutgoingEdgeTest : public ::testing::Test {
- public:
- void Initialize() {
- Graph graph(OpRegistry::Global());
-
-#if GOOGLE_CUDA
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
-#elif defined(TENSORFLOW_USE_SYCL)
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
-#else
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0";
-#endif
-
- Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
- test::FillValues<float>(&a_tensor, {42.0});
- Node* a = test::graph::Constant(&graph, a_tensor);
- a->set_assigned_device_name(kDeviceName);
-
- Node* c = test::graph::Constant(&graph, a_tensor);
- c->set_assigned_device_name(kDeviceName);
- c_ = c->name();
-
- // Node c will be executed only because of the control edge from c to y.
-    // Its output slot (slot 0) has no outgoing edge. This test verifies
-    // that the debugger can watch that slot properly.
- Node* y = test::graph::NoOp(&graph, {c});
- y->set_assigned_device_name(kDeviceName);
- y_ = y->name();
-
- test::graph::ToGraphDef(&graph, &def_);
- }
-
- string c_;
- string y_;
- GraphDef def_;
-};
-
-TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest,
- WatchSlotWithoutOutgoingEdge) {
- Initialize();
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- // Supply completion and value callbacks
- mutex mu;
-
- string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(c_, ":", 0), 0, "DebugIdentity");
-
- Notification callbacks_done;
-
- std::vector<Tensor> debug_identity_tensor_vals;
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &callbacks_done, &debug_identity_node_name,
- &debug_identity_tensor_vals](
- const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
-
- if (node_name == debug_identity_node_name && output_slot == 0) {
- debug_identity_tensor_vals.push_back(tensor_value);
-
- if (!callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- }
- });
-
- // Add DebugIdentity watch on c:0, which does not have an outgoing edge.
- RunOptions run_opts;
- run_opts.set_output_partition_graphs(true);
-
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_node_name(c_);
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops("DebugIdentity");
-
- TF_ASSERT_OK(session->Create(def_));
-
- // Invoke Session::Run() on y.
- std::vector<std::pair<string, Tensor>> inputs;
- std::vector<string> output_names;
- std::vector<string> target_nodes = {y_};
- std::vector<Tensor> outputs;
-
- RunMetadata run_metadata;
- Status s = session->Run(run_opts, inputs, output_names, target_nodes,
- &outputs, &run_metadata);
- TF_ASSERT_OK(s);
-
- // Wait for callbacks to complete.
- callbacks_done.WaitForNotification();
-
-  // Assert that the DebugIdentity node watching c:0, a slot without an
-  // outgoing edge, has been run.
- ASSERT_EQ(1, debug_identity_tensor_vals.size());
- auto mat_identity = debug_identity_tensor_vals[0].matrix<float>();
- ASSERT_EQ(42.0, mat_identity(0, 0));
-}
-
-class SessionDebugVariableTest : public ::testing::Test {
- public:
- void Initialize() {
- Graph graph(OpRegistry::Global());
-
-#if GOOGLE_CUDA
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
-#elif defined(TENSORFLOW_USE_SYCL)
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
-#else
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:CPU:0";
-#endif
-
- // Define variable node.
- var_node_name_ = "var";
- Node* var =
- test::graph::Var(&graph, DT_FLOAT, TensorShape({3}), var_node_name_);
- var->set_assigned_device_name(kDeviceName);
-
- // Define the initial value and the initial-value node.
- Tensor nan_nan_seven(DT_FLOAT, TensorShape({3}));
- nan_nan_seven.flat<float>()(0) = std::numeric_limits<float>::quiet_NaN();
- nan_nan_seven.flat<float>()(1) = std::numeric_limits<float>::quiet_NaN();
- nan_nan_seven.flat<float>()(2) = 7.0;
-
- init_val_node_name_ = "init_val";
- Node* init_val =
- test::graph::Constant(&graph, nan_nan_seven, init_val_node_name_);
- init_val->set_assigned_device_name(kDeviceName);
-
- // Define node for variable value initialization
- Node* init = test::graph::Assign(&graph, var, init_val);
- init->set_assigned_device_name(kDeviceName);
- init_node_name_ = init->name();
-
- // Define new value node
- Tensor nan_eight_eight(DT_FLOAT, TensorShape({3}));
- nan_eight_eight.flat<float>()(0) = std::numeric_limits<float>::quiet_NaN();
- nan_eight_eight.flat<float>()(1) = 8.0;
- nan_eight_eight.flat<float>()(2) = 8.0;
-
- Node* new_val = test::graph::Constant(&graph, nan_eight_eight);
- new_val->set_assigned_device_name(kDeviceName);
- new_val_node_name_ = new_val->name();
-
- // Define node for assigning new value
- Node* assign = test::graph::Assign(&graph, var, new_val);
- assign->set_assigned_device_name(kDeviceName);
- assign_node_name_ = assign->name();
-
- test::graph::ToGraphDef(&graph, &def_);
- }
-
- string var_node_name_;
- string init_val_node_name_;
- string init_node_name_;
- string new_val_node_name_;
- string assign_node_name_;
- GraphDef def_;
-};
-
-TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
- Initialize();
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- TF_ASSERT_OK(session->Create(def_));
-
- // Set up DebugTensorWatch for an uninitialized tensor (in node var).
- RunOptions run_opts;
- const string debug_identity = "DebugIdentity";
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_node_name(var_node_name_);
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops(debug_identity);
-
- // Expected name of the inserted debug node
- string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(var_node_name_, ":", 0), 0, debug_identity);
-
- // Supply completion and value callbacks
- mutex mu;
-  // Names of completed debug nodes that produced outputs
- std::vector<string> completed_debug_nodes;
-
- Notification callbacks_done;
- debug_gateway.SetNodeCompletionCallback(
- [this, &mu, &debug_identity_node_name, &completed_debug_nodes,
- &callbacks_done](const string& node_name, const bool any_output) {
- mutex_lock l(mu);
- if (any_output && (node_name == debug_identity_node_name)) {
- completed_debug_nodes.push_back(node_name);
- }
- });
-
- std::vector<Tensor> debug_identity_tensor_vals;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals,
- &callbacks_done](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
- if (node_name == debug_identity_node_name && output_slot == 0) {
- // output_slot == 0 carries the debug signal. Same below.
- debug_identity_tensor_vals.push_back(tensor_value);
- }
-
- // Set the notification once we have the value from the target node.
- if (node_name == init_node_name_ && !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
- // First run the initialization op
- std::vector<std::pair<string, Tensor>> inputs_init;
- std::vector<Tensor> outputs_init;
-
- RunMetadata run_metadata;
- Status s = session->Run(run_opts, inputs_init, {init_node_name_}, {},
- &outputs_init, &run_metadata);
- TF_ASSERT_OK(s);
-
- callbacks_done.WaitForNotification();
-
- ASSERT_EQ(1, completed_debug_nodes.size());
- ASSERT_EQ(
- 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(),
- debug_identity_node_name));
-
- // Assert the output reflects the uninitialized nature of var's tensor.
- ASSERT_EQ(1, debug_identity_tensor_vals.size());
- ASSERT_FALSE(debug_identity_tensor_vals[0].IsInitialized());
- ASSERT_EQ(DT_FLOAT, debug_identity_tensor_vals[0].dtype());
- ASSERT_EQ(TensorShape({3}), debug_identity_tensor_vals[0].shape());
-}
-
-TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
-  // The newly assigned value contains one NaN; the initial value contains two.
- Initialize();
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- TF_ASSERT_OK(session->Create(def_));
-
- // First run the initialization op
- std::vector<std::pair<string, Tensor>> inputs_init;
- std::vector<Tensor> outputs_init;
- Status s = session->Run(inputs_init, {init_node_name_}, {}, &outputs_init);
- TF_ASSERT_OK(s);
-
- // Create debug tensor watch options with two ref-type debug ops:
- // DebugIdentity and DebugNanCount
- RunOptions run_opts;
- run_opts.set_output_partition_graphs(true);
- const string debug_identity = "DebugIdentity";
- const string debug_nan_count = "DebugNanCount";
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_node_name(var_node_name_);
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops(debug_identity);
- tensor_watch_opts->add_debug_ops(debug_nan_count);
-
- char tempdir_template[] = "/tmp/tfdbg_XXXXXX";
- string temp_dir(mkdtemp(tempdir_template));
- tensor_watch_opts->add_debug_urls(strings::StrCat("file://", temp_dir));
-
- // Expected name of the inserted debug node
- string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(var_node_name_, ":", 0), 0, debug_identity);
- string debug_nan_count_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(var_node_name_, ":", 0), 1, debug_nan_count);
-
- // Supply completion and value callbacks
- mutex mu;
-  // Names of completed debug nodes that produced outputs
- std::vector<string> completed_debug_nodes;
-
- Notification callbacks_done;
- debug_gateway.SetNodeCompletionCallback(
- [this, &mu, &debug_identity_node_name, &debug_nan_count_node_name,
- &completed_debug_nodes,
- &callbacks_done](const string& node_name, const bool any_output) {
- mutex_lock l(mu);
- if (any_output && (node_name == debug_identity_node_name ||
- node_name == debug_nan_count_node_name)) {
- completed_debug_nodes.push_back(node_name);
- }
- });
-
- std::vector<Tensor> debug_identity_tensor_vals;
- std::vector<Tensor> debug_nan_count_tensor_vals;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &debug_identity_node_name, &debug_nan_count_node_name,
- &debug_identity_tensor_vals, &debug_nan_count_tensor_vals,
- &callbacks_done](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
- if (node_name == debug_identity_node_name && output_slot == 0) {
- // output_slot == 0 carries the debug signal. Same below.
- debug_identity_tensor_vals.push_back(tensor_value);
- } else if (node_name == debug_nan_count_node_name && output_slot == 0) {
- debug_nan_count_tensor_vals.push_back(tensor_value);
- }
-
- // Set the notification once we have the value from the target node.
- if (node_name == assign_node_name_ &&
- !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
-  // Request two targets: one fetched output and one non-fetched target node.
- std::vector<std::pair<string, Tensor>> inputs;
- std::vector<string> output_names = {assign_node_name_ + ":0"};
- std::vector<string> target_nodes = {assign_node_name_};
- std::vector<Tensor> outputs;
-
- // Run with RunOptions that has tensor watches
- RunMetadata run_metadata;
- s = session->Run(run_opts, inputs, output_names, target_nodes, &outputs,
- &run_metadata);
- TF_ASSERT_OK(s);
-
-#if GOOGLE_CUDA
- ASSERT_EQ(2, run_metadata.partition_graphs().size());
-#elif defined(TENSORFLOW_USE_SYCL)
- ASSERT_EQ(2, run_metadata.partition_graphs().size());
-#else
- ASSERT_EQ(1, run_metadata.partition_graphs().size());
-#endif
-
- // Wait for callbacks to complete.
- callbacks_done.WaitForNotification();
-
- // Verify that the update has happened properly.
- ASSERT_EQ(1, outputs.size());
- ASSERT_TRUE(std::isnan(outputs[0].vec<float>()(0)));
- ASSERT_EQ(8.0, outputs[0].vec<float>()(1)); // Expect new value
- ASSERT_EQ(8.0, outputs[0].vec<float>()(2)); // Expect new value
-
- // Verify that each of the two debug nodes has completed exactly once.
- ASSERT_EQ(2, completed_debug_nodes.size());
- ASSERT_EQ(
- 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(),
- debug_identity_node_name));
- ASSERT_EQ(
- 1, std::count(completed_debug_nodes.begin(), completed_debug_nodes.end(),
- debug_nan_count_node_name));
-
-  // Verify that the values from the ref identity node reflect the value
-  // before the new assign.
- ASSERT_EQ(1, debug_identity_tensor_vals.size());
-
- auto vec_identity = debug_identity_tensor_vals[0].vec<float>();
- ASSERT_TRUE(std::isnan(vec_identity(0)));
- ASSERT_TRUE(std::isnan(vec_identity(1)));
- ASSERT_EQ(7.0, vec_identity(2));
-
-  // Verify that the output from the NaN-count debug node indicates exactly
-  // two NaNs, i.e., it reflects the value before the new assign.
- ASSERT_EQ(1, debug_nan_count_tensor_vals.size());
- ASSERT_EQ(2, debug_nan_count_tensor_vals[0].scalar<int64>()());
-}
-
-#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_SYCL)
-class SessionDebugGPUSwitchTest : public ::testing::Test {
- public:
- void Initialize() {
- Graph graph(OpRegistry::Global());
-
-#ifdef GOOGLE_CUDA
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:GPU:0";
-#elif TENSORFLOW_USE_SYCL
- const string kDeviceName = "/job:localhost/replica:0/task:0/device:SYCL:0";
-#endif
-
- Tensor vb(DT_BOOL, TensorShape({}));
- vb.scalar<bool>()() = true;
- Tensor vi(DT_INT64, TensorShape({}));
-    vi.scalar<int64>()() = 42;
-    // Since pred is true, vi is forwarded to the second output port of sw.
-
- Node* pred = test::graph::Constant(&graph, vb);
- pred->set_assigned_device_name(kDeviceName);
- pred_node_name_ = pred->name();
-
- Node* value = test::graph::Constant(&graph, vi);
-    value->set_assigned_device_name(kDeviceName);
- value_node_name_ = value->name();
-
- Node* sw = test::graph::Switch(&graph, value, pred);
- sw->set_assigned_device_name(kDeviceName);
- sw_node_name_ = sw->name();
-
- Node* z = test::graph::Identity(&graph, sw, 1);
-    z->set_assigned_device_name(kDeviceName);
- z_node_name_ = z->name();
-
- test::graph::ToGraphDef(&graph, &def_);
- }
-
- string pred_node_name_;
- string value_node_name_;
- string sw_node_name_;
- string z_node_name_;
- GraphDef def_;
-};
-
-// Test for debug-watching tensors marked as HOST_MEMORY on GPU.
-TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
- Initialize();
- auto session = CreateSession();
- ASSERT_TRUE(session != nullptr);
-
- DebugGateway debug_gateway(session.get());
-
- RunOptions run_opts;
- run_opts.set_output_partition_graphs(true);
- // This is the name of the boolean tensor fed as pred to the Switch node.
- // On GPU, this edge is HOST_MEMORY.
- const string watched_tensor = strings::StrCat(pred_node_name_, "/_1");
-
- const string debug_identity = "DebugIdentity";
- DebugTensorWatch* tensor_watch_opts =
- run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
- tensor_watch_opts->set_node_name(watched_tensor);
- tensor_watch_opts->set_output_slot(0);
- tensor_watch_opts->add_debug_ops(debug_identity);
-
- // Expected name of the inserted debug node
- string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
- strings::StrCat(watched_tensor, ":", 0), 0, debug_identity);
-
- // Supply completion and value callbacks
- mutex mu;
- // Completed nodes with and without outputs
- std::vector<string> completed_nodes_w_outputs;
- std::vector<string> completed_nodes_wo_outputs;
-
- Notification callbacks_done;
- debug_gateway.SetNodeCompletionCallback(
- [&mu, &completed_nodes_w_outputs, &completed_nodes_wo_outputs](
- const string& node_name, const bool any_output) {
- mutex_lock l(mu);
- if (any_output) {
- completed_nodes_w_outputs.push_back(node_name);
- } else {
- completed_nodes_wo_outputs.push_back(node_name);
- }
- });
-
- std::vector<Tensor> debug_identity_tensor_vals;
-
- debug_gateway.SetNodeValueCallback(
- [this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals,
- &callbacks_done](const string& node_name, const int output_slot,
- const Tensor& tensor_value, const bool is_ref) {
- mutex_lock l(mu);
- if (node_name == debug_identity_node_name && output_slot == 0) {
- debug_identity_tensor_vals.push_back(tensor_value);
- }
-
- // Set the notification once we have the value from the target node.
- if (node_name == z_node_name_ && !callbacks_done.HasBeenNotified()) {
- callbacks_done.Notify();
- }
- });
-
- TF_ASSERT_OK(session->Create(def_));
-
- std::vector<std::pair<string, Tensor>> inputs;
-
-  // Request two targets: one fetched output and one non-fetched target node.
- std::vector<string> output_names = {z_node_name_ + ":0"};
- std::vector<string> target_nodes = {z_node_name_};
- std::vector<Tensor> outputs;
-
- RunMetadata run_metadata;
- Status s = session->Run(run_opts, inputs, output_names, target_nodes,
- &outputs, &run_metadata);
- TF_ASSERT_OK(s);
-
- ASSERT_EQ(2, run_metadata.partition_graphs().size());
-
- // Wait for callbacks to complete.
- callbacks_done.WaitForNotification();
-
- ASSERT_EQ(1, debug_identity_tensor_vals.size());
- ASSERT_TRUE(debug_identity_tensor_vals[0].scalar<bool>()());
-}
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_SYCL
-
-} // end namespace
-} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 75f8a19e9c..2059b1ce0d 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -494,9 +494,11 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
],
)
@@ -636,12 +638,12 @@ tf_cuda_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
- "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_master_service_impl",
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index e2f13df19f..6c146036ae 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -261,7 +261,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
/*shape_map=*/nullptr);
- // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
+ // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
if (!debug_options.debug_tensor_watch_opts().empty()) {
TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
debug_options, subgraph.get(), params.device));
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
index 09e96cbd40..62b18a45b1 100644
--- a/tensorflow/core/distributed_runtime/master_test.cc
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/allocator.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
-#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index d6c493c022..4a10d99a60 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -201,11 +201,11 @@ cc_library(
srcs = ["grpc_remote_master.cc"],
hdrs = ["grpc_remote_master.h"],
deps = [
+ ":grpc_master_service_impl",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
- "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:master_interface",
],
@@ -219,18 +219,28 @@ cc_library(
deps = [
":async_service_interface",
":grpc_call",
+ ":grpc_master_service_impl",
":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
- "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master",
],
alwayslink = 1,
)
cc_library(
+ name = "grpc_master_service_impl",
+ srcs = ["grpc_master_service_impl.cc"],
+ hdrs = ["grpc_master_service_impl.h"],
+ deps = [
+ "//tensorflow:grpc++",
+ "//tensorflow/core:master_proto_cc",
+ ],
+)
+
+cc_library(
name = "rpc_rendezvous_mgr",
srcs = ["rpc_rendezvous_mgr.cc"],
hdrs = ["rpc_rendezvous_mgr.h"],
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 0ebc084cb6..b7eb3c9015 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -42,12 +42,12 @@ string MakeAddress(const string& job, int task) {
return strings::StrCat("/job:", job, "/replica:0/task:", task);
}
+// Allows the host to be a raw IP (either v4 or v6).
Status ValidateHostPortPair(const string& host_port) {
uint32 port;
- std::vector<string> parts = str_util::Split(host_port, ':');
- // Must be host:port, port must be a number, host must not contain a '/'.
- if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
- parts[0].find("/") != string::npos) {
+ auto colon_index = host_port.find_last_of(':');
+ if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
+ host_port.substr(0, colon_index).find("/") != string::npos) {
return errors::InvalidArgument("Could not interpret \"", host_port,
"\" as a host-port pair.");
}
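
The rewritten ValidateHostPortPair splits at the last ':' instead of on every
colon, so bracketed IPv6 literals such as "[::1]:2222" keep their internal
colons on the host side. A minimal standalone sketch of the same rule
(hypothetical helper; std::stoul stands in loosely for strings::safe_strtou32):

    #include <cstdint>
    #include <string>

    // Sketch only: split "host:port" at the last colon, as the patched
    // ValidateHostPortPair does, so IPv6 hosts like "[::]" stay intact.
    bool SplitHostPort(const std::string& host_port, std::string* host,
                       uint32_t* port) {
      const auto colon = host_port.find_last_of(':');
      if (colon == std::string::npos) return false;
      const std::string port_str = host_port.substr(colon + 1);
      if (port_str.empty() ||
          port_str.find_first_not_of("0123456789") != std::string::npos) {
        return false;  // mimics safe_strtou32 rejecting non-numeric input
      }
      *port = static_cast<uint32_t>(std::stoul(port_str));
      *host = host_port.substr(0, colon);
      return host->find('/') == std::string::npos;  // no '/' in the host
    }

Under this rule "[2002:a9c:258e::]:2222" parses with host "[2002:a9c:258e::]"
and port 2222, while "example.com/abc:2222", "[::]:2222/" and "[::]:" are all
rejected, matching the new expectations in grpc_channel_test.cc below.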
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index a17acc85b3..f07a5a0974 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -150,10 +150,15 @@ TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
EXPECT_TRUE(NewHostPortGrpcChannel("127.0.0.1:2222", &mock_ptr).ok());
EXPECT_TRUE(NewHostPortGrpcChannel("example.com:2222", &mock_ptr).ok());
EXPECT_TRUE(NewHostPortGrpcChannel("fqdn.example.com.:2222", &mock_ptr).ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("[2002:a9c:258e::]:2222", &mock_ptr).ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("[::]:2222", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:2222", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("127.0.0.1:2222/", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]/:2222", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]:2222/", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]:", &mock_ptr).ok());
}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index 2c2c1d484a..127dea2882 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -36,12 +36,12 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/master.pb.h"
-#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
new file mode 100644
index 0000000000..770a0fcf14
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -0,0 +1,164 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
+
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
+
+namespace tensorflow {
+
+namespace grpc {
+
+static const char* grpcMasterService_method_names[] = {
+ "/tensorflow.MasterService/CreateSession",
+ "/tensorflow.MasterService/ExtendSession",
+ "/tensorflow.MasterService/PartialRunSetup",
+ "/tensorflow.MasterService/RunStep",
+ "/tensorflow.MasterService/CloseSession",
+ "/tensorflow.MasterService/ListDevices",
+ "/tensorflow.MasterService/Reset",
+ "/tensorflow.MasterService/MakeCallable",
+ "/tensorflow.MasterService/RunCallable",
+ "/tensorflow.MasterService/ReleaseCallable",
+};
+
+std::unique_ptr<MasterService::Stub> MasterService::NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options) {
+ std::unique_ptr<MasterService::Stub> stub(new MasterService::Stub(channel));
+ return stub;
+}
+
+MasterService::Stub::Stub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel)
+ : channel_(channel),
+ rpcmethod_CreateSession_(grpcMasterService_method_names[0],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel),
+ rpcmethod_ExtendSession_(grpcMasterService_method_names[1],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel),
+ rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel),
+ rpcmethod_RunStep_(grpcMasterService_method_names[3],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_CloseSession_(grpcMasterService_method_names[4],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_ListDevices_(grpcMasterService_method_names[5],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_Reset_(grpcMasterService_method_names[6],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_MakeCallable_(grpcMasterService_method_names[7],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_RunCallable_(grpcMasterService_method_names[8],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel) {}
+
+::grpc::Status MasterService::Stub::CreateSession(
+ ::grpc::ClientContext* context, const CreateSessionRequest& request,
+ CreateSessionResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_CreateSession_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ExtendSession(
+ ::grpc::ClientContext* context, const ExtendSessionRequest& request,
+ ExtendSessionResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_ExtendSession_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::PartialRunSetup(
+ ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_PartialRunSetup_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context,
+ const RunStepRequest& request,
+ RunStepResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_RunStep_,
+ context, request, response);
+}
+
+::grpc::Status MasterService::Stub::CloseSession(
+ ::grpc::ClientContext* context, const CloseSessionRequest& request,
+ CloseSessionResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_CloseSession_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ListDevices(
+ ::grpc::ClientContext* context, const ListDevicesRequest& request,
+ ListDevicesResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_ListDevices_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::Reset(::grpc::ClientContext* context,
+ const ResetRequest& request,
+ ResetResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Reset_,
+ context, request, response);
+}
+
+::grpc::Status MasterService::Stub::MakeCallable(
+ ::grpc::ClientContext* context, const MakeCallableRequest& request,
+ MakeCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_MakeCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::RunCallable(
+ ::grpc::ClientContext* context, const RunCallableRequest& request,
+ RunCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_RunCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_ReleaseCallable_, context, request, response);
+}
+
+MasterService::AsyncService::AsyncService() {
+ int method_len = sizeof(grpcMasterService_method_names) /
+ sizeof(grpcMasterService_method_names[0]);
+ for (int i = 0; i < method_len; ++i) {
+ AddMethod(new ::grpc::internal::RpcServiceMethod(
+ grpcMasterService_method_names[i],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
+ ::grpc::Service::MarkMethodAsync(i);
+ }
+}
+
+MasterService::AsyncService::~AsyncService() {}
+
+} // namespace grpc
+
+} // namespace tensorflow
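
For orientation, a hypothetical client-side sketch of the hand-written stub
defined above, which replaces the stub previously generated from
master_service.grpc.pb.h (the target address and wrapper function are
assumptions, not part of this change):

    #include "grpcpp/grpcpp.h"
    #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"

    // Hypothetical: issue one blocking ListDevices call through the stub.
    void ListRemoteDevices() {
      auto channel = ::grpc::CreateChannel(
          "localhost:2222", ::grpc::InsecureChannelCredentials());
      auto stub = tensorflow::grpc::MasterService::NewStub(channel);
      ::grpc::ClientContext ctx;
      tensorflow::ListDevicesRequest req;
      tensorflow::ListDevicesResponse resp;
      ::grpc::Status s = stub->ListDevices(&ctx, req, &resp);
      if (s.ok()) {
        // resp.local_device() / resp.remote_device() list the devices.
      }
    }

Every stub method is a BlockingUnaryCall over the corresponding entry in
grpcMasterService_method_names, which is what makes the generated file
replaceable by this small hand-maintained one.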
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
new file mode 100644
index 0000000000..751f2633e7
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -0,0 +1,224 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
+
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
+
+#include "tensorflow/core/protobuf/master.pb.h"
+
+namespace grpc {
+class CompletionQueue;
+class Channel;
+class RpcService;
+class ServerCompletionQueue;
+class ServerContext;
+} // namespace grpc
+
+namespace tensorflow {
+
+namespace grpc {
+
+// Implementation of `tensorflow.MasterService`, based on the
+// definition in "//tensorflow/core/protobuf/master_service.proto",
+// and the gRPC generated stub and service classes.
+// See that file for the definition of methods and messages.
+class MasterService final {
+ public:
+ class StubInterface {
+ public:
+ virtual ~StubInterface() {}
+ virtual ::grpc::Status CreateSession(::grpc::ClientContext* context,
+ const CreateSessionRequest& request,
+ CreateSessionResponse* response) = 0;
+ virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context,
+ const ExtendSessionRequest& request,
+ ExtendSessionResponse* response) = 0;
+ virtual ::grpc::Status PartialRunSetup(
+ ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) = 0;
+ virtual ::grpc::Status RunStep(::grpc::ClientContext* context,
+ const RunStepRequest& request,
+ RunStepResponse* response) = 0;
+ virtual ::grpc::Status CloseSession(::grpc::ClientContext* context,
+ const CloseSessionRequest& request,
+ CloseSessionResponse* response) = 0;
+ virtual ::grpc::Status ListDevices(::grpc::ClientContext* context,
+ const ListDevicesRequest& request,
+ ListDevicesResponse* response) = 0;
+ virtual ::grpc::Status Reset(::grpc::ClientContext* context,
+ const ResetRequest& request,
+ ResetResponse* response) = 0;
+ virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) = 0;
+ virtual ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) = 0;
+ virtual ::grpc::Status ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) = 0;
+ };
+ class Stub final : public StubInterface {
+ public:
+ Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
+ ::grpc::Status CreateSession(::grpc::ClientContext* context,
+ const CreateSessionRequest& request,
+ CreateSessionResponse* response) override;
+ ::grpc::Status ExtendSession(::grpc::ClientContext* context,
+ const ExtendSessionRequest& request,
+ ExtendSessionResponse* response) override;
+ ::grpc::Status PartialRunSetup(::grpc::ClientContext* context,
+ const PartialRunSetupRequest& request,
+ PartialRunSetupResponse* response) override;
+ ::grpc::Status RunStep(::grpc::ClientContext* context,
+ const RunStepRequest& request,
+ RunStepResponse* response) override;
+ ::grpc::Status CloseSession(::grpc::ClientContext* context,
+ const CloseSessionRequest& request,
+ CloseSessionResponse* response) override;
+ ::grpc::Status ListDevices(::grpc::ClientContext* context,
+ const ListDevicesRequest& request,
+ ListDevicesResponse* response) override;
+ ::grpc::Status Reset(::grpc::ClientContext* context,
+ const ResetRequest& request,
+ ResetResponse* response) override;
+ ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) override;
+ ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) override;
+ ::grpc::Status ReleaseCallable(::grpc::ClientContext* context,
+ const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) override;
+
+ private:
+ std::shared_ptr< ::grpc::ChannelInterface> channel_;
+ const ::grpc::internal::RpcMethod rpcmethod_CreateSession_;
+ const ::grpc::internal::RpcMethod rpcmethod_ExtendSession_;
+ const ::grpc::internal::RpcMethod rpcmethod_PartialRunSetup_;
+ const ::grpc::internal::RpcMethod rpcmethod_RunStep_;
+ const ::grpc::internal::RpcMethod rpcmethod_CloseSession_;
+ const ::grpc::internal::RpcMethod rpcmethod_ListDevices_;
+ const ::grpc::internal::RpcMethod rpcmethod_Reset_;
+ const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_RunCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_;
+ };
+ static std::unique_ptr<Stub> NewStub(
+ const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+ const ::grpc::StubOptions& options = ::grpc::StubOptions());
+
+ class AsyncService : public ::grpc::Service {
+ public:
+ AsyncService();
+ virtual ~AsyncService();
+ void RequestCreateSession(
+ ::grpc::ServerContext* context, CreateSessionRequest* request,
+ ::grpc::ServerAsyncResponseWriter<CreateSessionResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(0, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestExtendSession(
+ ::grpc::ServerContext* context, ExtendSessionRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ExtendSessionResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(1, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestPartialRunSetup(
+ ::grpc::ServerContext* context, PartialRunSetupRequest* request,
+ ::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(2, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestRunStep(
+ ::grpc::ServerContext* context, RunStepRequest* request,
+ ::grpc::ServerAsyncResponseWriter<RunStepResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(3, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestCloseSession(
+ ::grpc::ServerContext* context, CloseSessionRequest* request,
+ ::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(4, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestListDevices(
+ ::grpc::ServerContext* context, ListDevicesRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(5, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestReset(
+ ::grpc::ServerContext* context, ResetRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ResetResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(6, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestMakeCallable(
+ ::grpc::ServerContext* context, MakeCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(7, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestRunCallable(
+ ::grpc::ServerContext* context, RunCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(8, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestReleaseCallable(
+ ::grpc::ServerContext* context, ReleaseCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(9, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ };
+};
+
+} // namespace grpc
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
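
The AsyncService half follows the stock gRPC async-unary pattern: each
RequestFoo method registers interest in one incoming call of that type on a
completion queue. A hedged sketch of arming a single ListDevices request this
way (the real server drives this through grpc_call.h in grpc_master_service.cc
rather than by hand):

    tensorflow::grpc::MasterService::AsyncService service;
    ::grpc::ServerBuilder builder;
    builder.AddListeningPort("0.0.0.0:2222",
                             ::grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    std::unique_ptr<::grpc::ServerCompletionQueue> cq =
        builder.AddCompletionQueue();
    std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();

    ::grpc::ServerContext ctx;
    tensorflow::ListDevicesRequest req;
    ::grpc::ServerAsyncResponseWriter<tensorflow::ListDevicesResponse>
        responder(&ctx);
    // The tag surfaces on the completion queue when a call arrives; the
    // handler then fills a ListDevicesResponse and calls responder.Finish().
    service.RequestListDevices(&ctx, &req, &responder, cq.get(), cq.get(),
                               /*tag=*/reinterpret_cast<void*>(1));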
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index 6c2940553c..b832a2115c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -19,13 +19,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/master.pb.h"
-#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 2c833d11a9..db14f6473e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -152,16 +152,14 @@ Status GrpcServer::Init(
" was not defined in job \"",
server_def_.job_name(), "\"");
}
- const std::vector<string> hostname_port =
- str_util::Split(iter->second, ':');
- if (hostname_port.size() != 2 ||
- !strings::safe_strto32(hostname_port[1], &requested_port)) {
+ auto colon_index = iter->second.find_last_of(':');
+ if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
+ &requested_port)) {
return errors::InvalidArgument(
"Could not parse port for local server from \"", iter->second,
- "\"");
- } else {
- break;
+ "\".");
}
+ break;
}
}
if (requested_port == -1) {
@@ -343,11 +341,13 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
const string host_port = channel_cache_->TranslateTask(name_prefix);
int requested_port;
- if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
+ auto colon_index = host_port.find_last_of(':');
+ if (!strings::safe_strto32(host_port.substr(colon_index + 1),
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
- channel_cache_->TranslateTask(name_prefix), "\".");
+ host_port, "\".");
}
+
if (requested_port != bound_port_) {
return errors::InvalidArgument("Requested port ", requested_port,
" differs from expected port ", bound_port_);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index fd1c150fa7..fdce1b10e0 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -452,15 +452,12 @@ class GrpcSessionFactory : public SessionFactory {
return str_util::StartsWith(options.target, kSchemePrefix);
}
- Session* NewSession(const SessionOptions& options) override {
- std::unique_ptr<GrpcSession> ret;
- Status s = GrpcSession::Create(options, &ret);
- if (s.ok()) {
- return ret.release();
- } else {
- LOG(ERROR) << "Error during session construction: " << s.ToString();
- return nullptr;
- }
+ Status NewSession(const SessionOptions& options,
+ Session** out_session) override {
+ std::unique_ptr<GrpcSession> session;
+ TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
+ *out_session = session.release();
+ return Status::OK();
}
// Invokes the session specific static method to reset containers.
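
With this change, SessionFactory::NewSession reports failures through Status
instead of logging them and returning nullptr. A sketch of the caller side
under the new contract (the factory lookup and surrounding function are
assumed, not shown in this patch):

    Session* raw = nullptr;
    Status s = factory->NewSession(options, &raw);
    if (!s.ok()) {
      return s;  // the error now propagates instead of collapsing to nullptr
    }
    std::unique_ptr<Session> session(raw);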
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 5eeed6e382..45b989f6e2 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -99,6 +99,32 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
}
}
+void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
+ const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+ const StatusCallback& done) {
+ if (!group_leader_.empty()) {
+ LOG(ERROR) << "GetStepSequence called at non-group-leader";
+ done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
+ } else {
+ mutex_lock l(sequence_mu_);
+ for (int64 graph_key : request->graph_key()) {
+ auto it = sequence_table_.find(graph_key);
+ GraphKeySequence* gks = nullptr;
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ gks->next_step_id_ = NewRandomStepId();
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ StepSequence* ss = response->add_step_sequence();
+ ss->set_graph_key(graph_key);
+ ss->set_next_step_id(gks->next_step_id_);
+ }
+ done(Status::OK());
+ }
+}
+
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
const GetStepSequenceResponse& resp) {
mutex_lock l(sequence_mu_);
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
index e9f3f0ebe8..c9581fa00f 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
@@ -42,6 +42,12 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
virtual ~RpcCollectiveExecutorMgr();
+ // This function should only be called at the group_leader, by an RPC.
+ // Other needs for StepIds should be satisfied by NextStepId.
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) override;
+
void RefreshStepIdSequenceAsync(int64 graph_key,
const StatusCallback& done) override;
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
index 37b83d82be..0323300fdd 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -121,4 +122,50 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) {
EXPECT_GT(llabs(y - z), 3);
}
+TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) {
+ int64 x = cme_->NextStepId(3);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);
+ int64 y = cme_->NextStepId(4);
+ EXPECT_EQ(y, CollectiveExecutor::kInvalidId);
+ GetStepSequenceRequest request;
+ GetStepSequenceResponse response;
+ request.add_graph_key(3);
+ request.add_graph_key(4);
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());
+ std::unordered_map<int64, int64> values;
+ for (const auto& ss : response.step_sequence()) {
+ values[ss.graph_key()] = ss.next_step_id();
+ }
+ EXPECT_NE(values[3], CollectiveExecutor::kInvalidId);
+ EXPECT_NE(values[4], CollectiveExecutor::kInvalidId);
+ // Re-get, should be same values.
+ response.Clear();
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());
+ for (const auto& ss : response.step_sequence()) {
+ EXPECT_EQ(values[ss.graph_key()], ss.next_step_id());
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto
index c6cda06342..f8553cf5bb 100644
--- a/tensorflow/core/framework/api_def.proto
+++ b/tensorflow/core/framework/api_def.proto
@@ -64,12 +64,6 @@ message ApiDef {
// to use a non-deprecated endpoint instead will be printed. If all
// endpoints are deprecated, set deprecation_message in ApiDef instead.
bool deprecated = 3;
- // Deprecated: set deprecated to "true" instead. We can auto-generate
- // the message.
- // If this endpoint is deprecated, set deprecation_message to a
- // message that should be logged when the endpoint is used.
- // The message should indicate alternative endpoint to use, if any.
- string deprecation_message = 2;
}
repeated Endpoint endpoint = 3;
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ed3318d841..21c6940b62 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1231,11 +1231,13 @@ Status ConcatV2Shape(InferenceContext* c) {
c->num_inputs() - 1 /* dim_index */);
}
-Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) {
- ShapeHandle shape_x = c->input(0);
- ShapeHandle shape_y = c->input(1);
+Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
+ ShapeHandle shape_x,
+ ShapeHandle shape_y,
+ ShapeHandle* out) {
+ CHECK_NOTNULL(out);
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
- c->set_output(0, c->UnknownShape());
+ *out = c->UnknownShape();
return Status::OK();
}
const int32 rank_x = c->Rank(shape_x);
@@ -1293,7 +1295,7 @@ Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) {
}
}
- c->set_output(output_index, c->MakeShape(dims));
+ *out = c->MakeShape(dims);
return Status::OK();
}
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 87bb133d92..2bedce1d6a 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -267,7 +267,22 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Shape function for binary operators that broadcast their inputs
 // and write the output to output_index.
-Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index);
+// Note: out cannot be NULL.
+Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
+ ShapeHandle shape_x,
+ ShapeHandle shape_y,
+ ShapeHandle* out);
+
+// Shape function for binary operators that broadcast their inputs
+// and write the result to output_index.
+inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
+ int output_index) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(
+ BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
+ c->set_output(output_index, out);
+ return Status::OK();
+}
// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
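
The refactoring splits the broadcast computation out of the output-setting
wrapper so that shape functions can broadcast an arbitrary pair of shapes,
not just inputs 0 and 1. A sketch of a hypothetical caller (the op and the
input indices are illustrative, not part of this change):

    // An op that broadcasts its second and third inputs can now call the
    // helper directly instead of going through BroadcastBinaryOpOutputShapeFn.
    Status SelectLikeShapeFn(shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
          c, c->input(1), c->input(2), &out));
      c->set_output(0, out);
      return Status::OK();
    }
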
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 8a332fa1d8..58feec90f0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -263,11 +263,13 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
outputs_(num_outputs),
temp_memory_allocated_(0),
persistent_memory_allocated_(0) {
- Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
params_->ensure_eigen_gpu_device();
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ if (params_->eigen_gpu_device != nullptr) {
+ Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
+ params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
+ params_->op_device_context,
+ eigen_gpu_allocator);
+ }
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 6c4c3a2ac1..d9fe42fcbb 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1044,7 +1044,6 @@ class OpKernelContext {
// For control flow.
FrameAndIter frame_iter() const { return params_->frame_iter; }
bool is_input_dead() const { return params_->is_input_dead; }
- bool* is_output_dead() { return &is_output_dead_; }
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
@@ -1143,8 +1142,6 @@ class OpKernelContext {
// Constructed only if <params->record_tensor_accesses>.
ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
- bool is_output_dead_ = false;
-
// The following data members are only used when allocation tracking is
// enabled.
mutable mutex stats_mu_;
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index b5c2c2aac8..80c76df255 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -24,9 +24,6 @@ namespace tensorflow {
TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
-SafeTensorId::SafeTensorId(StringPiece str, int idx)
- : SafeTensorId(str.ToString(), idx) {}
-
SafeTensorId::SafeTensorId(const TensorId& id)
: SafeTensorId(id.first.ToString(), id.second) {}
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
index b0978b4120..0ba3942618 100644
--- a/tensorflow/core/graph/tensor_id.h
+++ b/tensorflow/core/graph/tensor_id.h
@@ -62,13 +62,10 @@ TensorId ParseTensorName(StringPiece name);
struct SafeTensorId : public std::pair<string, int> {
typedef std::pair<string, int> Base;
- // Inherit the set of constructors.
- using Base::pair;
-
// NOTE(skyewm): this is required on some platforms. I'm not sure why the
- // using statement above isn't always sufficient.
+ // using "using Base::pair;" isn't always sufficient.
SafeTensorId() : Base() {}
- SafeTensorId(StringPiece str, int idx);
+ SafeTensorId(const string& str, int idx) : Base(str, idx) {}
SafeTensorId(const TensorId& id);
string ToString() const {
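
With the inherited pair constructors gone, a SafeTensorId can no longer be
built implicitly from a StringPiece; call sites convert to string at the
boundary instead. A short sketch of the resulting pattern (variable names
illustrative):

    // ParseTensorName returns a TensorId whose first member is a StringPiece
    // view; converting it up front keeps SafeTensorId owning its storage.
    TensorId id = ParseTensorName("foo:0");
    SafeTensorId safe_id(id.first.ToString(), id.second);
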
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index b054068299..f3dc2c2091 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -41,6 +41,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":utils",
+ "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:op_types",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 0c02876ac5..83a8326e79 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -422,11 +423,108 @@ class SymbolicShapeRefiner {
return it->second.inference_context.get();
}
- // Forward the shapes from the function's fanin to the function body,
- // then call PropagateShapes.
- // Returns an error if 'node' is not a function node.
- Status UpdateFunction(const NodeDef* node, bool* refined) {
- return UpdateNode(node, refined);
+ // Forward the shapes from the function input nodes to
+ // the argument nodes (which are Placeholder nodes), then
+ // perform shape inference on the function body.
+ //
+ // Finally, propagate the shape information of the function body's
+ // output nodes back to the function node `node`.
+ //
+ // In the event of an error, UpdateNode will simply set `node`'s
+ // output shape to be Unknown.
+ Status UpdateFunction(const NodeDef* node) {
+ auto it = fun_to_grappler_function_item_.find(node->op());
+ if (it == fun_to_grappler_function_item_.end()) {
+ return errors::InvalidArgument(
+ node->op(), " was not previously added to SymbolicShapeRefiner.");
+ }
+
+ GrapplerFunctionItem& grappler_function_item = it->second;
+ GraphView gv(&grappler_function_item.graph);
+
+ // Forward shapes from function input nodes to argument nodes.
+ for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
+ auto& fun_input = grappler_function_item.input(i);
+ if (fun_input.placeholders.size() > 1) {
+ // TODO(jmdecker): Handle case with multiple input placeholders
+ return errors::Unimplemented(
+ "Input arguments with multiple placeholders are not yet "
+ "supported.");
+ }
+ NodeDef* fun_node = gv.GetNode(fun_input.input_name);
+ const string& input = node->input(i);
+ const string& node_name = NodeName(input);
+
+ if (IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Function inputs should not contain control nodes.");
+ }
+
+ NodeDef* input_node = graph_.GetNode(node_name);
+ if (input_node == nullptr) {
+ return errors::FailedPrecondition(node_name,
+ " was not found in the graph.");
+ }
+
+ InferenceContext* input_inference_context = GetContext(input_node);
+ if (input_inference_context == nullptr) {
+ return errors::FailedPrecondition(
+ "Inference context has not been created for ", node_name);
+ }
+
+ int output_port_num = NodePosition(input);
+ AttrValue attr_output_shape;
+ TensorShapeProto proto;
+ const auto& handle = input_inference_context->output(output_port_num);
+ input_inference_context->ShapeHandleToProto(handle, &proto);
+ *attr_output_shape.mutable_shape() = proto;
+ (*fun_node->mutable_attr())["shape"] = attr_output_shape;
+ }
+
+ // Perform inference on function body.
+ GraphProperties gp(grappler_function_item);
+ TF_RETURN_IF_ERROR(gp.InferStatically(true));
+
+ // Set the function node's output shapes from the function body's return nodes.
+ auto ic = GetContext(node);
+ int output = 0;
+ for (auto const& out_arg : grappler_function_item.outputs()) {
+ if (out_arg.output_tensors.size() > 1) {
+ // TODO(jmdecker): Handle case of multiple output tensors
+ return errors::Unimplemented(
+ "Output arguments with multiple output tensors are not yet "
+ "supported.");
+ }
+
+ string out_tensor = out_arg.output_tensors[0];
+ auto out_tensor_pieces = str_util::Split(out_tensor, ",");
+ string node_name = out_tensor_pieces[0];
+ int port_id;
+
+ // Check if port_id was included in out_tensor
+ if (out_tensor_pieces.size() <= 1) {
+ port_id = 0;
+ } else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) {
+ return errors::FailedPrecondition(
+ "Failed string to integer conversion for ", out_tensor_pieces[1]);
+ }
+
+ const NodeDef* retnode = gv.GetNode(node_name);
+ if (retnode == nullptr) {
+ return errors::FailedPrecondition("Unable to find return node ",
+ node_name, " for ", node->name());
+ }
+
+ auto output_properties = gp.GetOutputProperties(retnode->name());
+ auto const& outprop = output_properties[port_id];
+ const TensorShapeProto& shape = outprop.shape();
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
+ ic->set_output(output, out);
+ output++;
+ }
+
+ return Status::OK();
}
Status UpdateNode(const NodeDef* node, bool* refined) {
@@ -436,6 +534,7 @@ class SymbolicShapeRefiner {
node_context = CHECK_NOTNULL(GetNodeContext(node));
*refined = true;
}
+
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have, update the node input shapes.
InferenceContext* inference_context = node_context->inference_context.get();
@@ -455,7 +554,8 @@ class SymbolicShapeRefiner {
if (c == nullptr) {
return errors::FailedPrecondition(
"Input ", dst_input, " ('", input->name(), "') for '",
- node->name(), "' was not previously added to ShapeRefiner.");
+ node->name(),
+ "' was not previously added to SymbolicShapeRefiner.");
}
if (IsConstant(*input)) {
@@ -565,6 +665,21 @@ class SymbolicShapeRefiner {
node_context->inference_context->set_input_tensors_as_shapes(
input_tensors_as_shapes);
+ // Properly handle function nodes.
+ if (node_context->op_data && node_context->op_data->is_function_op) {
+ // TODO(jmdecker): Detect if the input shapes have changed for this
+ // function. Note that when we hit a function call node, refined will be
+ // true, as the updates to the call node will have changed, even if it's
+ // the same function being called twice with the same input shapes.
+ // Example: simple_function.pbtxt
+ if (UpdateFunction(node).ok()) {
+ return Status::OK();
+ } else {
+ VLOG(1) << "UpdateFunction failed for " << node->op()
+ << ". Defaulting to ShapeUnknown.";
+ }
+ }
+
// Update the shapes of the outputs.
return InferShapes(*node, node_context);
}
@@ -681,7 +796,39 @@ class SymbolicShapeRefiner {
return true;
}
- Status AddFunction(const NodeDef* node) { return Status::OK(); }
+ Status AddFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
+ if (it != fun_to_grappler_function_item_.end()) {
+ return Status::OK();
+ }
+
+ const FunctionDef* function_def =
+ CHECK_NOTNULL(function_library_.Find(function_node->op()));
+
+ GrapplerFunctionItem grappler_function_item;
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ *function_def, function_library_, &grappler_function_item));
+
+ if (grappler_function_item.inputs().size() > function_node->input_size()) {
+ return errors::FailedPrecondition(
+ "Function input size should be smaller than node input size.");
+ }
+
+ for (int i = grappler_function_item.inputs().size();
+ i < function_node->input_size(); ++i) {
+ const string& input = function_node->input(i);
+ if (!IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Found regular input (", input,
+ ") instead of control nodes for node ", function_node->name());
+ }
+ }
+
+ fun_to_grappler_function_item_[function_def->signature().name()] =
+ grappler_function_item;
+
+ return Status::OK();
+ }
Status AddNode(const NodeDef* node) {
NodeContext& node_ctx = node_to_context_[node];
@@ -911,6 +1058,8 @@ class SymbolicShapeRefiner {
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
+ std::unordered_map<string, GrapplerFunctionItem>
+ fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
};
@@ -1082,13 +1231,9 @@ Status GraphProperties::UpdateShapes(
// Set shapes and types of Queue ops, if needed.
TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
- auto c = shape_refiner->GetNodeContext(n);
- if (c && c->op_data && c->op_data->is_function_op) {
- TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes));
- } else {
- // Rely on regular TF shape refinement for all the other nodes.
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
- }
+ // Rely on regular TF shape refinement for all the other nodes.
+ // UpdateNode calls UpdateFunction if a function node is detected.
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}
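
The net effect, which the tests below verify, is that noinline function
calls no longer default to unknown output shapes. A minimal usage sketch,
following the test setup (file and node names are taken from the tests):

    GrapplerItem item;
    string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
                                   "simple_function.pbtxt");
    TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(false));
    // Previously the MyAdd call's output had unknown rank; with UpdateFunction
    // in place it is inferred as float[1,2] from the function body.
    const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
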
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index aa787ae620..1be19d291a 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -783,7 +783,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
-TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
+TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
// Test graph produced in python using:
/*
@function.Defun(*[tf.float32] * 2, noinline=True)
@@ -796,7 +796,6 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
z = MyAdd(x, y)
z = MyAdd(x, z)
*/
- // Check that the shape inference code infers what it can.
GrapplerItem item;
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
"simple_function.pbtxt");
@@ -806,15 +805,258 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_TRUE(out_prop.shape().unknown_rank());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "large_function_graph.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto out_props = properties.GetOutputProperties("y0");
+ EXPECT_EQ(2, out_props.size());
+
+ const OpInfo::TensorProperties& out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_EQ(4, out_prop0.shape().dim_size());
+ EXPECT_EQ(128, out_prop0.shape().dim(0).size());
+ EXPECT_EQ(112, out_prop0.shape().dim(1).size());
+ EXPECT_EQ(112, out_prop0.shape().dim(2).size());
+ EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& out_prop1 = out_props[1];
+ EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
+ EXPECT_EQ(128, out_prop1.shape().dim(0).size());
+ EXPECT_EQ(112, out_prop1.shape().dim(1).size());
+ EXPECT_EQ(112, out_prop1.shape().dim(2).size());
+ EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+
+ const auto in_props = properties.GetInputProperties("y0");
+ EXPECT_EQ(4, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop0 = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
+ EXPECT_EQ(1, in_prop0.shape().dim_size());
+ EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_EQ(4, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(1, in_prop1.shape().dim(1).size());
+ EXPECT_EQ(24, in_prop1.shape().dim(2).size());
+ EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& in_prop2 = in_props[2];
+ EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
+ EXPECT_EQ(4, in_prop2.shape().dim_size());
+ EXPECT_EQ(128, in_prop2.shape().dim(0).size());
+ EXPECT_EQ(224, in_prop2.shape().dim(1).size());
+ EXPECT_EQ(224, in_prop2.shape().dim(2).size());
+ EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& in_prop3 = in_props[3];
+ EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
+ EXPECT_EQ(4, in_prop3.shape().dim_size());
+ EXPECT_EQ(7, in_prop3.shape().dim(0).size());
+ EXPECT_EQ(7, in_prop3.shape().dim(1).size());
+ EXPECT_EQ(3, in_prop3.shape().dim(2).size());
+ EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_error.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
+ EXPECT_EQ(1, out_props.size());
+
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
+ EXPECT_EQ(2, in_props.size());
+
const OpInfo::TensorProperties& in_prop = in_props[0];
EXPECT_EQ(DT_FLOAT, in_prop.dtype());
EXPECT_FALSE(in_prop.shape().unknown_rank());
EXPECT_EQ(2, in_prop.shape().dim_size());
EXPECT_EQ(1, in_prop.shape().dim(0).size());
EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ return tf.add(x, y)
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ return tf.add(x, y)
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch_2.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
+ c = tf.add(x, a)
+ d = tf.add(y, b)
+ return c
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch_shapes.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(3, in_prop1.shape().dim(1).size());
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt
new file mode 100644
index 0000000000..c3f0a6c95d
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt
@@ -0,0 +1,117 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MyAdd_yabA4wXEdM4"
+ op: "MyAdd_yabA4wXEdM4"
+ input: "Const"
+ input: "Const_1"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_yabA4wXEdM4"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "add_1"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "Add:z:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "Add_1"
+ op: "Add"
+ input: "Add:z:0"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "add_1"
+ value: "Add_1:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt
new file mode 100644
index 0000000000..d6d856ce41
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt
@@ -0,0 +1,251 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_MPaeanipb7o"
+ op: "MyAdd_MPaeanipb7o"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_MPaeanipb7o"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt
new file mode 100644
index 0000000000..e57d9d7076
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt
@@ -0,0 +1,251 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_MPaeanipb7o"
+ op: "MyAdd_MPaeanipb7o"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_MPaeanipb7o"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt
new file mode 100644
index 0000000000..e9afa91886
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt
@@ -0,0 +1,317 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_lEKAAnIwI5I"
+ op: "MyAdd_lEKAAnIwI5I"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_lEKAAnIwI5I"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "Const:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "Add_1"
+ op: "Add"
+ input: "y"
+ input: "Const_1:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt
new file mode 100644
index 0000000000..415c347a1d
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt
@@ -0,0 +1,597 @@
+node {
+ name: "Const/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 64
+ }
+ }
+ }
+}
+node {
+ name: "input_0_0"
+ op: "RandomUniform"
+ input: "Const/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_1/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\001\000\000\000\001\000\000\000\030\000\000\000@\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_1_0"
+ op: "RandomUniform"
+ input: "Const_1/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_2/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\200\000\000\000\340\000\000\000\340\000\000\000\003\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_2_0"
+ op: "RandomUniform"
+ input: "Const_2/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_3/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\007\000\000\000\007\000\000\000\003\000\000\000\010\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_3_0"
+ op: "RandomUniform"
+ input: "Const_3/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "y0"
+ op: "BiasAddx1_Conv2Dx1_DepthwiseConv2dNativex1_Relux1_95"
+ input: "input_0_0"
+ input: "input_1_0"
+ input: "input_2_0"
+ input: "input_3_0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+}
+node {
+ name: "shape"
+ op: "Shape"
+ input: "y0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "zeros"
+ op: "ZerosLike"
+ input: "shape"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "ones"
+ op: "OnesLike"
+ input: "shape"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "slice_0"
+ op: "Slice"
+ input: "y0"
+ input: "zeros"
+ input: "ones"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "identity_0"
+ op: "Identity"
+ input: "slice_0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "shape_1"
+ op: "Shape"
+ input: "y0:1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "zeros_1"
+ op: "ZerosLike"
+ input: "shape_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "ones_1"
+ op: "OnesLike"
+ input: "shape_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "slice_1"
+ op: "Slice"
+ input: "y0:1"
+ input: "zeros_1"
+ input: "ones_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "identity_1"
+ op: "Identity"
+ input: "slice_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+library {
+ function {
+ signature {
+ name: "BiasAddx1_Conv2Dx1_DepthwiseConv2dNativex1_Relux1_95"
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/biases/read"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/pointwise_weights/read"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "random_uniform"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/depthwise_weights/read"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/BiasAdd"
+ op: "BiasAdd"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d:output:0"
+ input: "InceptionV2/Conv2d_1a_7x7/biases/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ op: "Relu"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/BiasAdd:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d"
+ op: "Conv2D"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise:output:0"
+ input: "InceptionV2/Conv2d_1a_7x7/pointwise_weights/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "use_cudnn_on_gpu"
+ value {
+ b: true
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ op: "DepthwiseConv2dNative"
+ input: "random_uniform"
+ input: "InceptionV2/Conv2d_1a_7x7/depthwise_weights/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "SAME"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 2
+ i: 2
+ i: 1
+ }
+ }
+ }
+ }
+ ret {
+ key: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ value: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu:activations:0"
+ }
+ ret {
+ key: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ value: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise:output:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -161,8 +161,6 @@ bool IsExit(const NodeDef& node) {
return op == "Exit" || op == "RefExit";
}
-bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
-
bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -60,7 +60,6 @@ bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node);
-bool IsExp(const NodeDef& node);
bool IsFill(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index b7369c7b4a..97862d1ed0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,42 +178,6 @@ NodeDef* GetTailOfIdempotentChain(
is_idempotent_non_branching);
}
-// GetElementUnexhaustive tries to get the value of an element in a tensor and
-// turn it into complex128 type. It only check for a limited number of data
-// types, so it's unexhaustive.
-bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
- complex128* element) {
- if (dtypes.find(t.dtype()) == dtypes.end()) return false;
- switch (t.dtype()) {
- case DT_BFLOAT16:
- *element = complex128(t.flat<bfloat16>()(i));
- return true;
- case DT_HALF:
- *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
- return true;
- case DT_INT32:
- *element = complex128(t.flat<int32>()(i));
- return true;
- case DT_INT64:
- *element = complex128(t.flat<int64>()(i));
- return true;
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return true;
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return true;
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return true;
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return true;
- default:
- return false;
- }
-}
-
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2397,13 +2361,7 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
- // input data type is not supported by Pow. Skip.
- return Status::OK();
- }
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
@@ -2474,6 +2432,31 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
Status SetElementToOne(int i, Tensor* t) {
switch (t->dtype()) {
case DT_INT32:
@@ -2561,10 +2544,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
+ if (!GetElement(constant, k, &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2589,81 +2569,30 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
return Status::OK();
}
-};
-class ConvertExpm1Stage : public ArithmeticOptimizerStage {
- public:
- explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
- ~ConvertExpm1Stage() override = default;
-
- bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
-
- Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- NodeDef* input;
- TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- if (!IsSub(*input)) {
- return Status::OK();
- }
-
- if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
- return Status::OK();
- }
-
- const auto& t =
- ctx().graph_properties->GetInputProperties(input->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(input->name())[1];
- for (int k = 0; k < c.shape().dim_size(); ++k) {
- // Skip if c shape is not fully determined.
- if (c.shape().dim(k).size() < 0) {
- return Status::OK();
- }
- }
- TensorShapeProto broadcast_shape;
- if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
- return Status::OK();
- }
- if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
- // skip if the non-constant tensor doesn't have the same shape after
- // broadcast.
- return Status::OK();
- }
- if (TensorShape::IsValid(c.shape()) && c.has_value()) {
- Tensor constant(c.dtype(), c.shape());
- if (!constant.FromProto(c.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- c.value().DebugString());
- }
- complex128 element;
- for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
- // input data type is not supported by expm1. Skip.
- return Status::OK();
- }
- if (element != complex128(1)) {
- // current element is not 1. Skip.
- return Status::OK();
- }
- }
- NodeDef *x, *y;
- TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
- TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
- node->set_op("Expm1");
- node->set_input(0, input->input(0));
- node->add_input(AsControlDependency(y->name()));
- ForwardControlDependencies(node, {input});
-
- AddToOptimizationQueue(node);
- AddToOptimizationQueue(input);
- AddToOptimizationQueue(x);
- AddToOptimizationQueue(y);
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
}
- return Status::OK();
}
};
@@ -3165,8 +3094,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
- if (options_.convert_expm1)
- pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
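
Both GetElement helpers above rely on the same trick: every supported dtype
is widened to complex128 so that one comparison loop can decide whether all
elements of a constant equal some target value. A standalone sketch of that
pattern (helper name and container are illustrative):

    #include <complex>
    #include <vector>

    using complex128 = std::complex<double>;

    // Returns true iff every element equals `expected`, e.g. a Pow exponent
    // that is uniformly 2.0 or a Log1p addend that is uniformly 1.0.
    bool AllElementsEqual(const std::vector<complex128>& elements,
                          complex128 expected) {
      for (const complex128& e : elements) {
        if (e != expected) return false;
      }
      return !elements.empty();
    }
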
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..00c02d19bd 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -77,7 +77,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool simplify_aggregation = true;
bool convert_pow = true;
bool convert_log1p = true;
- bool convert_expm1 = true;
bool unary_ops_composition = true;
// Choose which arithmetic optimizer stages will be enabled for a given
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 54fdc01adb..c387b00303 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -279,11 +279,6 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
- void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
- DisableAllStages(optimizer);
- optimizer->options_.convert_expm1 = true;
- }
-
void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.unary_ops_composition = true;
@@ -2547,43 +2542,6 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
CompareGraphs(want, got);
}
-TEST_F(ArithmeticOptimizerTest, Expm1) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
- auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
- auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
- auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2);
- auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3);
- Output out1 = ops::Exp(s.WithOpName("out1"), s12);
- Output out2 = ops::Exp(s.WithOpName("out2"), s23);
-
- GrapplerItem item;
- item.fetch = {"out1", "out2"};
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(2, tensors_expected.size());
-
- GraphDef got;
- ArithmeticOptimizer optimizer;
- EnableOnlyExpm1(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &got);
- auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(2, tensors.size());
-
- GraphDef want;
- AddNode("x1", "Const", {}, {}, &want);
- AddNode("x2", "Const", {}, {}, &want);
- AddNode("x3", "Const", {}, {}, &want);
- AddNode("s23", "Sub", {"x2", "x3"}, {}, &want);
- AddNode("out1", "Expm1",
- {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
- &want);
- AddNode("out2", "Exp", {"s23"}, {}, &want);
-
- CompareGraphs(want, got);
-}
-
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index aad24e36e0..3cb9d4d61c 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -4,6 +4,39 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
+ name = "function_rename",
+ srcs = ["function_rename.cc"],
+ hdrs = [
+ "function_rename.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_rename_test",
+ srcs = ["function_rename_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_rename",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -139,6 +172,7 @@ cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
+ ":function_rename",
":map_and_batch_fusion",
":noop_elimination",
":shuffle_and_repeat_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc
new file mode 100644
index 0000000000..8cf044d1bd
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename.cc
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
+
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ int n = output->mutable_library()->function_size();
+ for (int i = 0; i < n; ++i) {
+ FunctionDef* fn = output->mutable_library()->mutable_function(i);
+ fn->mutable_signature()->set_name(fn->signature().name() + "world");
+ }
+
+ return Status::OK();
+}
+
+void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/function_rename.h
new file mode 100644
index 0000000000..23ad9470ff
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class FunctionRename : public CustomGraphOptimizer {
+ public:
+ FunctionRename() = default;
+ ~FunctionRename() override = default;
+
+  string name() const override { return "_test_only_function_rename"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
new file mode 100644
index 0000000000..56b8a960a7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(FunctionRenameTest, RenameFunction) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ FunctionDef *fn = graph->mutable_library()->add_function();
+ fn->mutable_signature()->set_name("hello");
+
+ FunctionRename optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_EQ(output.library().function(0).signature().name(), "helloworld");
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
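For context, the `_test_only_function_rename` optimizer registered above is retrieved by name through the custom-optimizer registry. A minimal sketch of that lookup (assuming the registry header already included in function_rename.cc; `MakeTestOnlyFunctionRename` is an illustrative name, not part of this change):

    #include <memory>

    #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
    #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

    namespace tensorflow {
    namespace grappler {

    // Returns the optimizer registered via REGISTER_GRAPH_OPTIMIZER_AS above,
    // or nullptr if nothing was registered under that name.
    std::unique_ptr<CustomGraphOptimizer> MakeTestOnlyFunctionRename() {
      return CustomGraphOptimizerRegistry::CreateByNameOrNull(
          "_test_only_function_rename");
    }

    }  // end namespace grappler
    }  // end namespace tensorflow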
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 07360d594b..7599cf7db2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -882,7 +882,6 @@ tf_kernel_library(
"tile_functor_gpu.cu.cc",
],
prefix = "tile_ops",
- textual_hdrs = ["tile_ops_gpu_impl.h"],
deps = ARRAY_DEPS,
)
@@ -2087,6 +2086,7 @@ IMAGE_DEPS = [
"//tensorflow/core:jpeg_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:png_internal",
"//tensorflow/core:protos_all_cc",
]
@@ -2661,7 +2661,7 @@ tf_kernel_library(
tf_kernel_library(
name = "summary_image_op",
prefix = "summary_image_op",
- deps = LOGGING_DEPS,
+ deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"],
)
tf_kernel_library(
@@ -2706,17 +2706,16 @@ cc_library(
],
)
-MANIP_DEPS = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:manip_ops_op_lib",
- "//third_party/eigen3",
-]
-
tf_kernel_library(
name = "roll_op",
prefix = "roll_op",
- deps = MANIP_DEPS,
+ deps = [
+ ":bounds_check",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:manip_ops_op_lib",
+ "//third_party/eigen3",
+ ],
)
tf_cc_test(
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index a87b63f913..902327aaea 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -113,7 +113,7 @@ class ConcatBaseOp : public OpKernel {
int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
- const auto in = values[i];
+ const auto& in = values[i];
const bool in_is_scalar = IsLegacyScalar(in.shape());
OP_REQUIRES(
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
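The one-character change above matters because `auto` deduces `Tensor` by value, copying the tensor handle (and bumping buffer reference counts) on every loop iteration, while `auto&` binds to the stored element. A standalone illustration of the deduction rule (plain C++, not TensorFlow code):

    #include <string>
    #include <vector>

    void Example(const std::vector<std::string>& values) {
      const auto copy = values[0];   // deduces std::string: makes a copy
      const auto& ref = values[0];   // binds to the stored element: no copy
      (void)copy;
      (void)ref;
    }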
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index cda94c5538..e04fa20414 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -563,15 +563,6 @@ tf_kernel_library(
)
tf_kernel_library(
- name = "identity_dataset_op",
- srcs = ["identity_dataset_op.cc"],
- deps = [
- ":dataset",
- "//tensorflow/core:framework",
- ],
-)
-
-tf_kernel_library(
name = "optimize_dataset_op",
srcs = ["optimize_dataset_op.cc"],
deps = [
@@ -619,7 +610,6 @@ tf_kernel_library(
":generator_dataset_op",
":group_by_reducer_dataset_op",
":group_by_window_dataset_op",
- ":identity_dataset_op",
":interleave_dataset_op",
":iterator_ops",
":map_and_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ee58341cfd..82da385405 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -214,6 +214,9 @@ Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -248,6 +251,9 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -304,6 +310,9 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
});
f_opts.step_container = &step_container;
f_opts.runner = runner;
+ if (lib->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -351,6 +360,9 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
});
f_opts.step_container = step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
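The same three-line guard is added at four call sites above. A hypothetical helper capturing the pattern (a sketch only; `MaybeCreateRendezvous` is an invented name, not TensorFlow API):

    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/framework/types.h"

    // Sketch: functions placed on a non-CPU device need their own rendezvous
    // so that tensors can be transferred to and from the function's device.
    inline void MaybeCreateRendezvous(
        tensorflow::FunctionLibraryRuntime* lib,
        tensorflow::FunctionLibraryRuntime::Options* f_opts) {
      if (lib->device()->device_type() != tensorflow::DEVICE_CPU) {
        f_opts->create_rendezvous = true;
      }
    }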
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index aae62ad2fe..0981e42ba1 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -197,6 +197,9 @@ class GeneratorDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
+ GeneratorDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/identity_dataset_op.cc b/tensorflow/core/kernels/data/identity_dataset_op.cc
deleted file mode 100644
index e28f188336..0000000000
--- a/tensorflow/core/kernels/data/identity_dataset_op.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <map>
-
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/data/dataset.h"
-
-namespace tensorflow {
-namespace {
-
-// The purpose of identity dataset is to serve as a placeholder when performing
-// optimizations. It is not expected to be surfaced in the Python API.
-class IdentityDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit IdentityDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- }
-
- protected:
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- *output = new Dataset(ctx, input);
- }
-
- private:
- class Dataset : public GraphDatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Identity")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override { return "IdentityDatasetOp::Dataset"; }
-
- protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- return errors::Unimplemented(
- strings::StrCat(prefix(), "::GetNextInternal"));
- }
- };
-
- const DatasetBase* const input_;
- };
-
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("IdentityDataset").Device(DEVICE_CPU),
- IdentityDatasetOp);
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 8e327b239d..2a94a54f3d 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -1150,22 +1150,45 @@ class DeserializeIteratorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU),
+ IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU),
+ IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(
+ Name("MakeIterator").Device(DEVICE_GPU).HostMemory("dataset"),
+ MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
AnonymousIteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
+ AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
IteratorGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU),
+ IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
+ IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2").Device(DEVICE_CPU),
+ IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 8965858e8d..276f5f89c8 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -54,8 +54,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset =
new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
- core::ScopedUnref unref(dataset);
- OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, output));
+ OP_REQUIRES_OK(ctx, dataset->Optimize(ctx));
+ *output = dataset;
}
private:
@@ -73,7 +73,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
input_->Ref();
}
- ~Dataset() override { input_->Unref(); }
+ ~Dataset() override {
+ input_->Unref();
+ optimized_input_->Unref();
+ }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
@@ -81,7 +84,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
}
- Status Optimize(OpKernelContext* ctx, DatasetBase** output) {
+ Status Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
@@ -89,18 +92,20 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ VLOG(3) << "Before optimization: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
-
+ VLOG(3) << "After optimization: " << graph_def.DebugString();
+ flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
+ graph_def.library()));
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
- GraphRunner graph_runner(ctx->env());
- // Once rewrites that add/modify functions are introduced, we will need
- // persist the results in a function library runtime.
+ GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
{output_node}, &outputs));
- TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], output));
- (*output)->Ref();
+ TF_RETURN_IF_ERROR(
+ GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
+ optimized_input_->Ref();
return Status::OK();
}
@@ -113,6 +118,18 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* optimizations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, optimizations_node}, output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -120,15 +137,38 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
+ return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
+ &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- return errors::Unimplemented(
- strings::StrCat(prefix(), "::GetNextInternal"));
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = ctx->lib();
+ params.function_library = dataset()->flib_def_;
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext iter_ctx(params);
+ return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
}
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
};
Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
@@ -136,16 +176,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
// Add a fake sink node to allow rewriting the actual sink node.
NodeDef* node = graph_def->mutable_node()->Add();
node->set_name("FakeSink");
- node->set_op("IdentityDataset");
+ node->set_op("SinkDataset");
node->add_input(*output_node);
- {
- grappler::GraphView graph(graph_def);
- NodeDef* sink = graph.GetNode(*output_node);
- (*node->mutable_attr())["output_shapes"] =
- sink->attr().at("output_shapes");
- (*node->mutable_attr())["output_types"] =
- sink->attr().at("output_types");
- }
// Create metagraph.
MetaGraphDef meta_graph_def;
@@ -162,10 +194,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
for (const string& optimization : optimizations_) {
rewriter_config.add_optimizers(optimization);
}
- // If no optimizations were specified, supply a non-existent optimization
- // to prevent Grappler from applying the default set of optimizations as
- // some of them do not work out of the box at the moment (e.g. because we
- // have no cost model for dataset ops).
+ // If no optimizations were specified, supply a non-existent
+ // optimization to prevent Grappler from applying the default set of
+ // optimizations as some of them do not work out of the box at the
+ // moment (e.g. because we have no cost model for dataset ops).
if (optimizations_.empty()) {
rewriter_config.add_optimizers("non-existent");
}
@@ -178,6 +210,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run optimizer.
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "Performing the following optimizations:";
+ for (const string& optimization : optimizations_) {
+ LOG(INFO) << " " << optimization;
+ }
+ }
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
@@ -192,6 +230,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ DatasetBase* optimized_input_;
+ std::shared_ptr<FunctionLibraryDefinition> flib_def_;
const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;
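The rewrite above appends a `SinkDataset` node so Grappler has a stable fetch target even if the original output node is replaced. A reduced sketch of that step (assuming the proto headers named below; `AddFakeSink` is an illustrative wrapper, not part of the change):

    #include <string>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    // Sketch: append a fake sink consuming `output_node`, mirroring the
    // ApplyOptimizations step above, so the optimizer can rewrite the real
    // sink node without invalidating the fetch target.
    void AddFakeSink(tensorflow::GraphDef* graph_def,
                     const std::string& output_node) {
      tensorflow::NodeDef* node = graph_def->mutable_node()->Add();
      node->set_name("FakeSink");
      node->set_op("SinkDataset");
      node->add_input(output_node);
    }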
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 2bafb985ef..cc16108dce 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -357,7 +357,12 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
-
+REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
+ .Device(DEVICE_GPU)
+ .HostMemory("buffer_size")
+ .HostMemory("input_dataset")
+ .HostMemory("handle"),
+ PrefetchDatasetOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 990cbceac2..b4f81d9a70 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -51,7 +51,7 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
- Status ExportValues(OpKernelContext* context) {
+ Status ExportValues(OpKernelContext* context) override {
return errors::Unimplemented(
"ExportValues not supported by InitializableLookupTable "
"implementations");
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index f08dd4f750..f59843a07a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/non_max_suppression_op.h"
+#include <functional>
#include <queue>
#include <vector>
@@ -38,9 +39,32 @@ namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
+static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
+ const Tensor& scores) {
+ // The shape of 'scores' is [num_boxes]
+ OP_REQUIRES(context, scores.dims() == 1,
+ errors::InvalidArgument("scores must be 1-D",
+ scores.shape().DebugString()));
+ OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
+ errors::InvalidArgument("scores has incompatible shape"));
+}
+
+static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
+ const Tensor& overlaps,
+ int* num_boxes) {
+  // The shape of 'overlaps' is [num_boxes, num_boxes]
+ OP_REQUIRES(context, overlaps.dims() == 2,
+ errors::InvalidArgument("overlaps must be 2-D",
+ overlaps.shape().DebugString()));
+
+ *num_boxes = overlaps.dim_size(0);
+ OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
+ errors::InvalidArgument("overlaps must be square",
+ overlaps.shape().DebugString()));
+}
+
static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
- const Tensor& boxes,
- const Tensor& scores, int* num_boxes) {
+ const Tensor& boxes, int* num_boxes) {
// The shape of 'boxes' is [num_boxes, 4]
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
@@ -48,18 +72,12 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
*num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4,
errors::InvalidArgument("boxes must have 4 columns"));
-
- // The shape of 'scores' is [num_boxes]
- OP_REQUIRES(context, scores.dims() == 1,
- errors::InvalidArgument("scores must be 1-D",
- scores.shape().DebugString()));
- OP_REQUIRES(context, scores.dim_size(0) == *num_boxes,
- errors::InvalidArgument("scores has incompatible shape"));
}
-// Return intersection-over-union overlap between boxes i and j
-static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
-                        int j) {
+// Return whether the intersection-over-union overlap between boxes i and j
+// exceeds iou_threshold.
+static inline bool IOUGreaterThanThreshold(
+    typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
+    float iou_threshold) {
const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
@@ -78,24 +96,36 @@ static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
const float intersection_area =
std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- return intersection_area / (area_i + area_j - intersection_area);
+ const float iou = intersection_area / (area_i + area_j - intersection_area);
+ return iou > iou_threshold;
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
- const Tensor& scores, const Tensor& max_output_size,
- const float iou_threshold,
- const float score_threshold) {
- OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
- if (!context->status().ok()) {
- return;
- }
+static inline bool OverlapsGreaterThanThreshold(
+ typename TTypes<float, 2>::ConstTensor overlaps, int i, int j,
+ float overlap_threshold) {
+ return overlaps(i, j) > overlap_threshold;
+}
+
+static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
+ const Tensor& boxes, float threshold) {
+ typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
+ return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
+ std::placeholders::_2, threshold);
+}
+
+static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
+ const Tensor& overlaps, float threshold) {
+ typename TTypes<float, 2>::ConstTensor overlaps_data =
+ overlaps.tensor<float, 2>();
+ return std::bind(&OverlapsGreaterThanThreshold, overlaps_data,
+ std::placeholders::_1, std::placeholders::_2, threshold);
+}
+
+void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
+ int num_boxes, const Tensor& max_output_size,
+ const float score_threshold,
+ std::function<bool(int, int)> suppress_check_fn) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
- TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
@@ -120,11 +150,9 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
std::vector<int> selected;
std::vector<float> selected_scores;
Candidate next_candidate;
- float iou, original_score;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
next_candidate = candidate_priority_queue.top();
- original_score = next_candidate.score;
candidate_priority_queue.pop();
// Overlapping boxes are likely to have similar scores,
@@ -132,8 +160,10 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
// in order to see if `next_candidate` should be suppressed.
bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
- iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
- if (iou > iou_threshold) should_select = false;
+ if (suppress_check_fn(next_candidate.box_index, selected[j])) {
+ should_select = false;
+ break;
+ }
}
if (should_select) {
@@ -173,9 +203,19 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
+ OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_, score_threshold_val);
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
@@ -206,9 +246,19 @@ class NonMaxSuppressionV2Op : public OpKernel {
iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val, score_threshold_val);
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -244,8 +294,65 @@ class NonMaxSuppressionV3Op : public OpKernel {
score_threshold.shape().DebugString()));
const float score_threshold_val = score_threshold.scalar<float>()();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val, score_threshold_val);
+ OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
+ }
+};
+
+template <typename Device>
+class NonMaxSuppressionWithOverlapsOp : public OpKernel {
+ public:
+ explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // overlaps: [num_boxes, num_boxes]
+ const Tensor& overlaps = context->input(0);
+ // scores: [num_boxes]
+ const Tensor& scores = context->input(1);
+ // max_output_size: scalar
+ const Tensor& max_output_size = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+ max_output_size.shape().DebugString()));
+ // overlap_threshold: scalar
+ const Tensor& overlap_threshold = context->input(3);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
+ errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
+ overlap_threshold.shape().DebugString()));
+ const float overlap_threshold_val = overlap_threshold.scalar<float>()();
+
+ // score_threshold: scalar
+ const Tensor& score_threshold = context->input(4);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(score_threshold.shape()),
+ errors::InvalidArgument("score_threshold must be 0-D, got shape ",
+ score_threshold.shape().DebugString()));
+ const float score_threshold_val = score_threshold.scalar<float>()();
+
+ int num_boxes = 0;
+ ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn =
+ CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
+
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -258,4 +365,8 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
+ NonMaxSuppressionWithOverlapsOp<CPUDevice>);
+
} // namespace tensorflow
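The refactoring above threads a pluggable `suppress_check_fn` through the selection loop so that IOU-based and precomputed-overlap suppression share one implementation. A self-contained sketch of that greedy loop (simplified plain C++, not the kernel's exact code):

    #include <algorithm>
    #include <functional>
    #include <vector>

    // Greedy non-max suppression with a pluggable suppression criterion,
    // mirroring DoNonMaxSuppressionOp above: visit candidates in order of
    // decreasing score and drop any candidate suppressed by a prior pick.
    std::vector<int> GreedyNms(const std::vector<float>& scores, int max_output,
                               const std::function<bool(int, int)>& suppressed) {
      // Sort candidate indices by descending score.
      std::vector<int> order(scores.size());
      for (int i = 0; i < static_cast<int>(order.size()); ++i) order[i] = i;
      std::sort(order.begin(), order.end(),
                [&scores](int a, int b) { return scores[a] > scores[b]; });

      std::vector<int> selected;
      for (int candidate : order) {
        if (static_cast<int>(selected.size()) >= max_output) break;
        bool keep = true;
        for (int s : selected) {
          if (suppressed(candidate, s)) {  // e.g. IOU(candidate, s) > threshold
            keep = false;
            break;
          }
        }
        if (keep) selected.push_back(candidate);
      }
      return selected;
    }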
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index ed7db313bd..055161a35f 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -569,4 +569,241 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}
+//
+// NonMaxSuppressionWithOverlapsOp Tests
+//
+
+class NonMaxSuppressionWithOverlapsOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op",
+ "NonMaxSuppressionWithOverlaps")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+
+ void AddIoUInput(const std::vector<float>& boxes) {
+ ASSERT_EQ((boxes.size() % 4), 0);
+ size_t num_boxes = boxes.size() / 4;
+ std::vector<float> iou_overlaps(num_boxes * num_boxes);
+
+ // compute the pairwise IoU overlaps
+ auto corner_access = [&boxes](size_t box_idx, size_t corner_idx) {
+ return boxes[box_idx * 4 + corner_idx];
+ };
+ for (size_t i = 0; i < num_boxes; ++i) {
+ for (size_t j = 0; j < num_boxes; ++j) {
+ const float ymin_i =
+ std::min<float>(corner_access(i, 0), corner_access(i, 2));
+ const float xmin_i =
+ std::min<float>(corner_access(i, 1), corner_access(i, 3));
+ const float ymax_i =
+ std::max<float>(corner_access(i, 0), corner_access(i, 2));
+ const float xmax_i =
+ std::max<float>(corner_access(i, 1), corner_access(i, 3));
+ const float ymin_j =
+ std::min<float>(corner_access(j, 0), corner_access(j, 2));
+ const float xmin_j =
+ std::min<float>(corner_access(j, 1), corner_access(j, 3));
+ const float ymax_j =
+ std::max<float>(corner_access(j, 0), corner_access(j, 2));
+ const float xmax_j =
+ std::max<float>(corner_access(j, 1), corner_access(j, 3));
+ const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+
+ float iou;
+ if (area_i <= 0 || area_j <= 0) {
+ iou = 0.0;
+ } else {
+ const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
+ const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
+ const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
+ const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
+ const float intersection_area =
+ std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
+ std::max<float>(intersection_xmax - intersection_xmin, 0.0);
+ iou = intersection_area / (area_i + area_j - intersection_area);
+ }
+ iou_overlaps[i * num_boxes + j] = iou;
+ }
+ }
+
+ AddInputFromArray<float>(TensorShape({static_cast<signed>(num_boxes),
+ static_cast<signed>(num_boxes)}),
+ iou_overlaps);
+ }
+};
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectFromThreeClustersFlippedCoordinates) {
+ MakeOp();
+ AddIoUInput({1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f,
+ 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectAtMostTwoBoxesFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {2});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+ test::FillValues<int>(&expected, {3, 0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectAtMostThirtyBoxesFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectSingleBox) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectFromTenIdenticalBoxes) {
+ MakeOp();
+
+ int num_boxes = 10;
+ std::vector<float> corners(num_boxes * 4);
+ std::vector<float> scores(num_boxes);
+ for (int i = 0; i < num_boxes; ++i) {
+ corners[i * 4 + 0] = 0;
+ corners[i * 4 + 1] = 0;
+ corners[i * 4 + 2] = 1;
+ corners[i * 4 + 3] = 1;
+ scores[i] = .9;
+ }
+ AddIoUInput(corners);
+ AddInputFromArray<float>(TensorShape({num_boxes}), scores);
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInconsistentBoxAndScoreShapes) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(), "scores has incompatible shape"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInvalidOverlapsShape) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
+ AddInputFromArray<float>(TensorShape({2}), {0.5f, 0.5f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {0.f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "overlaps must be square"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestThresholdGreaterOne) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {1.2f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestThresholdSmallerZero) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {-0.2f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) {
+ MakeOp();
+ AddIoUInput({});
+ AddInputFromArray<float>(TensorShape({0}), {});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({0}));
+ test::FillValues<int>(&expected, {});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 71c1b56fbd..b5c6ba1da3 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -156,14 +156,8 @@ class PartitionedCallOp : public AsyncOpKernel {
// Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
// corresponding resource lives. This ensures that the Placer assigns ops that
- // access these resources to the appropriate devices. This method throws an
- // error if any two resource inputs live on different devices.
- //
- // TODO(akshayka): Remove the single-device constraint once we have a
- // mechanism for telling the Placer that an op supports heterogeneous
- // devices among its input resources.
+ // access these resources to the appropriate devices.
Status PinResourceArgs(Graph* graph, const OpInputList& args) {
- string device;
for (Node* node : graph->op_nodes()) {
string node_type = node->type_string();
if (node_type == FunctionLibraryDefinition::kArgOp) {
@@ -174,15 +168,7 @@ class PartitionedCallOp : public AsyncOpKernel {
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
- const string& handle_device = handle.device();
- if (device.empty()) {
- device = handle_device;
- } else if (device != handle_device) {
- return errors::Internal(
- "Resources must reside on a single device; observed devices ",
- device, " and ", handle_device);
- }
- node->set_assigned_device_name(handle_device);
+ node->set_assigned_device_name(handle.device());
}
}
}
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index af921e4815..c5292e1ae1 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -174,25 +174,20 @@ REGISTER_KERNEL_BUILDER(Name("VariableShape")
#endif // GOOGLE_CUDA
-class DestroyResourceOp : public OpKernel {
- public:
- explicit DestroyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx,
- ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
- }
+DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
+}
- void Compute(OpKernelContext* ctx) override {
- const ResourceHandle& p = HandleFromInput(ctx, 0);
- Status status = DeleteResource(ctx, p);
- if (ignore_lookup_error_ && errors::IsNotFound(status)) {
- return;
- }
- OP_REQUIRES_OK(ctx, status);
+void DestroyResourceOp::Compute(OpKernelContext* ctx) {
+ const ResourceHandle& p = HandleFromInput(ctx, 0);
+ Status status = DeleteResource(ctx, p);
+ if (ignore_lookup_error_ && errors::IsNotFound(status)) {
+ return;
}
-
- private:
- bool ignore_lookup_error_;
-};
+ OP_REQUIRES_OK(ctx, status);
+}
REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
DestroyResourceOp);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 8cae5d21f0..9b60106f13 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,15 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class DestroyResourceOp : public OpKernel {
+ public:
+ explicit DestroyResourceOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ bool ignore_lookup_error_;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
index 722116f86f..efa30438d9 100644
--- a/tensorflow/core/kernels/roll_op.cc
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
@@ -258,7 +259,7 @@ class RollOp : public OpKernel {
if (axis < 0) {
axis += num_dims;
}
- OP_REQUIRES(context, 0 <= axis && axis < num_dims,
+ OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
errors::InvalidArgument("axis ", axis, " is out of range"));
const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
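`FastBoundsCheck(axis, num_dims)` (from the newly included bounds_check.h) folds the two signed comparisons into a single unsigned one. A sketch of the underlying trick (illustrative reimplementation, not the header's exact code):

    // Casting to unsigned maps negative values far above any valid index,
    // so one comparison implements 0 <= v && v < n.
    inline bool FastBoundsCheckSketch(int v, int n) {
      return static_cast<unsigned int>(v) < static_cast<unsigned int>(n);
    }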
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 2da83a0288..d28e35157b 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -13,16 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
-
-
-// This file requires the following include because it uses CudaAtomicMax:
-// #include "tensorflow/core/util/cuda_kernel_helper.h"
-
-// Unfortunately we can't add the #include, since it breaks compilation for
-// non-GPU targets. This only breaks in clang, because it's more strict for
-// template code and CudaAtomicMax is used in template context.
+#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
index 2f87057f4e..6521dcf932 100644
--- a/tensorflow/core/kernels/sendrecv_ops.cc
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -160,7 +160,6 @@ Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx,
if (!is_dead) {
ctx->set_output(0, val);
}
- *ctx->is_output_dead() = is_dead;
}
done();
},
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 7abf382262..be72ee8066 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -7681,66 +7681,6 @@ op {
}
}
op {
- name: "AvgPool"
- input_arg {
- name: "value"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "ksize"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "strides"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
-}
-op {
name: "AvgPool3D"
input_arg {
name: "input"
@@ -8430,70 +8370,6 @@ op {
}
}
op {
- name: "AvgPoolGrad"
- input_arg {
- name: "orig_input_shape"
- type: DT_INT32
- }
- input_arg {
- name: "grad"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "ksize"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "strides"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
-}
-op {
name: "Barrier"
output_arg {
name: "handle"
@@ -10555,61 +10431,6 @@ op {
}
}
op {
- name: "BiasAdd"
- input_arg {
- name: "value"
- type_attr: "T"
- }
- input_arg {
- name: "bias"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_INT64
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_COMPLEX128
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
-}
-op {
name: "BiasAddGrad"
input_arg {
name: "out_backprop"
@@ -10802,57 +10623,6 @@ op {
}
}
op {
- name: "BiasAddGrad"
- input_arg {
- name: "out_backprop"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_INT64
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_COMPLEX128
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
-}
-op {
name: "BiasAddV1"
input_arg {
name: "value"
@@ -13457,81 +13227,6 @@ op {
}
}
op {
- name: "Conv2D"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- input_arg {
- name: "filter"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "use_cudnn_on_gpu"
- type: "bool"
- default_value {
- b: true
- }
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "Conv2DBackpropFilter"
input_arg {
name: "input"
@@ -13748,85 +13443,6 @@ op {
}
}
op {
- name: "Conv2DBackpropFilter"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- input_arg {
- name: "filter_sizes"
- type: DT_INT32
- }
- input_arg {
- name: "out_backprop"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "use_cudnn_on_gpu"
- type: "bool"
- default_value {
- b: true
- }
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "Conv2DBackpropInput"
input_arg {
name: "input_sizes"
@@ -14043,85 +13659,6 @@ op {
}
}
op {
- name: "Conv2DBackpropInput"
- input_arg {
- name: "input_sizes"
- type: DT_INT32
- }
- input_arg {
- name: "filter"
- type_attr: "T"
- }
- input_arg {
- name: "out_backprop"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "use_cudnn_on_gpu"
- type: "bool"
- default_value {
- b: true
- }
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "Conv3D"
input_arg {
name: "input"
@@ -18852,74 +18389,6 @@ op {
}
}
op {
- name: "DepthwiseConv2dNative"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- input_arg {
- name: "filter"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "DepthwiseConv2dNativeBackpropFilter"
input_arg {
name: "input"
@@ -19158,78 +18627,6 @@ op {
}
}
op {
- name: "DepthwiseConv2dNativeBackpropFilter"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- input_arg {
- name: "filter_sizes"
- type: DT_INT32
- }
- input_arg {
- name: "out_backprop"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "DepthwiseConv2dNativeBackpropInput"
input_arg {
name: "input_sizes"
@@ -19468,78 +18865,6 @@ op {
}
}
op {
- name: "DepthwiseConv2dNativeBackpropInput"
- input_arg {
- name: "input_sizes"
- type: DT_INT32
- }
- input_arg {
- name: "filter"
- type_attr: "T"
- }
- input_arg {
- name: "out_backprop"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_BFLOAT16
- type: DT_FLOAT
- type: DT_DOUBLE
- }
- }
- }
- attr {
- name: "strides"
- type: "list(int)"
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "dilations"
- type: "list(int)"
- default_value {
- list {
- i: 1
- i: 1
- i: 1
- i: 1
- }
- }
- }
-}
-op {
name: "Dequantize"
input_arg {
name: "input"
@@ -26276,29 +25601,6 @@ op {
}
}
op {
- name: "IdentityDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "IdentityN"
input_arg {
name: "input"
@@ -27869,6 +27171,36 @@ op {
is_stateful: true
}
op {
+ name: "IteratorFromStringHandleV2"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNext"
input_arg {
name: "iterator"
@@ -27929,6 +27261,34 @@ op {
is_stateful: true
}
op {
+ name: "IteratorV2"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "L2Loss"
input_arg {
name: "t"
@@ -32199,85 +31559,6 @@ op {
}
}
op {
- name: "MaxPoolGrad"
- input_arg {
- name: "orig_input"
- type_attr: "T"
- }
- input_arg {
- name: "orig_output"
- type_attr: "T"
- }
- input_arg {
- name: "grad"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "ksize"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "strides"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- default_value {
- type: DT_FLOAT
- }
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
-}
-op {
name: "MaxPoolGradGrad"
input_arg {
name: "orig_input"
@@ -32570,82 +31851,6 @@ op {
}
}
op {
- name: "MaxPoolGradGrad"
- input_arg {
- name: "orig_input"
- type_attr: "T"
- }
- input_arg {
- name: "orig_output"
- type_attr: "T"
- }
- input_arg {
- name: "grad"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "ksize"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "strides"
- type: "list(int)"
- has_minimum: true
- minimum: 4
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
-}
-op {
name: "MaxPoolGradGradV2"
input_arg {
name: "orig_input"
@@ -32922,78 +32127,6 @@ op {
}
}
op {
- name: "MaxPoolGradGradV2"
- input_arg {
- name: "orig_input"
- type_attr: "T"
- }
- input_arg {
- name: "orig_output"
- type_attr: "T"
- }
- input_arg {
- name: "grad"
- type_attr: "T"
- }
- input_arg {
- name: "ksize"
- type: DT_INT32
- }
- input_arg {
- name: "strides"
- type: DT_INT32
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
-}
-op {
name: "MaxPoolGradGradWithArgmax"
input_arg {
name: "input"
@@ -33562,81 +32695,6 @@ op {
}
}
op {
- name: "MaxPoolGradV2"
- input_arg {
- name: "orig_input"
- type_attr: "T"
- }
- input_arg {
- name: "orig_output"
- type_attr: "T"
- }
- input_arg {
- name: "grad"
- type_attr: "T"
- }
- input_arg {
- name: "ksize"
- type: DT_INT32
- }
- input_arg {
- name: "strides"
- type: DT_INT32
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "padding"
- type: "string"
- allowed_values {
- list {
- s: "SAME"
- s: "VALID"
- }
- }
- }
- attr {
- name: "data_format"
- type: "string"
- default_value {
- s: "NHWC"
- }
- allowed_values {
- list {
- s: "NHWC"
- s: "NCHW"
- s: "HWNC"
- s: "HWCN"
- }
- }
- }
- attr {
- name: "T"
- type: "type"
- default_value {
- type: DT_FLOAT
- }
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
- }
- }
- }
-}
-op {
name: "MaxPoolGradWithArgmax"
input_arg {
name: "input"
@@ -36321,6 +35379,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionWithOverlaps"
+ input_arg {
+ name: "overlaps"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "overlap_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
@@ -58885,6 +57970,17 @@ op {
}
}
op {
+ name: "SinkDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+}
+op {
name: "Size"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 3782bc8796..c8bc11155a 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -644,6 +644,14 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("IteratorV2")
+ .Output("handle: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("AnonymousIterator")
.Output("handle: resource")
.Attr("output_types: list(type) >= 1")
@@ -721,6 +729,13 @@ REGISTER_OP("IteratorFromStringHandle")
.Attr("output_shapes: list(shape) >= 0 = []")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("IteratorFromStringHandleV2")
+ .Input("string_handle: string")
+ .Output("resource_handle: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("SerializeIterator")
.Input("resource_handle: resource")
.Output("serialized: variant")
@@ -783,11 +798,9 @@ REGISTER_OP("DatasetToGraph")
.Output("graph: string")
.SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("IdentityDataset")
+REGISTER_OP("SinkDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("OptimizeDataset")
diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc
index 5aebdca1ea..2d9b4360de 100644
--- a/tensorflow/core/ops/debug_ops.cc
+++ b/tensorflow/core/ops/debug_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
-// EXPERIMENTAL: tfdbg debugger-inserted ops.
+// TensorFlow Debugger-inserted ops.
// These ops are used only internally by tfdbg. There is no API for users to
// directly create them. Users can create them indirectly by using
// RunOptions.debug_options during Session::Run() call. See tfdbg documentation
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index b49ad8e387..5f262db2ce 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -40,7 +40,11 @@ REGISTER_OP("SymbolicGradient")
if (types[i] == DT_RESOURCE) {
const std::vector<shape_inference::ShapeAndType>* handle_type =
c->input_handle_shapes_and_types(i);
- c->set_output(i, handle_type->at(0).shape);
+ if (handle_type != nullptr) {
+ c->set_output(i, handle_type->at(0).shape);
+ } else {
+ c->set_output(i, c->UnknownShape());
+ }
} else {
c->set_output(i, c->input(i));
}
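The SymbolicGradient change above avoids dereferencing a null handle-shape vector for DT_RESOURCE inputs. A standalone sketch of the same defensive pattern, assuming a hypothetical PassThroughResource op:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::shape_inference::InferenceContext;

REGISTER_OP("PassThroughResource")
    .Input("in: resource")
    .Output("out: resource")
    .SetShapeFn([](InferenceContext* c) {
      const auto* handle_data = c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr && !handle_data->empty()) {
        c->set_output(0, handle_data->at(0).shape);  // propagate known shape
      } else {
        c->set_output(0, c->UnknownShape());  // nothing recorded for handle
      }
      return tensorflow::Status::OK();
    });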
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 87f4991134..50ced1ff73 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -709,4 +709,36 @@ REGISTER_OP("NonMaxSuppressionV3")
return Status::OK();
});
+REGISTER_OP("NonMaxSuppressionWithOverlaps")
+ .Input("overlaps: float")
+ .Input("scores: float")
+ .Input("max_output_size: int32")
+ .Input("overlap_threshold: float")
+ .Input("score_threshold: float")
+ .Output("selected_indices: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle overlaps;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle overlap_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+      // The overlaps input is a 2-D float Tensor of shape
+      // [num_boxes, num_boxes].
+      DimensionHandle unused;
+      // overlaps dim 0 and scores dim 0 are both num_boxes.
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
+      // overlaps must be square: dim 1 is also num_boxes.
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ });
+
} // namespace tensorflow
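The shape function above uses InferenceContext::Merge both to align overlaps with scores and to require that overlaps is square. A reduced sketch of the Merge idiom, assuming a hypothetical RequireSquareMatrix op:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

REGISTER_OP("RequireSquareMatrix")
    .Input("m: float")
    .Output("out: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle m;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &m));
      DimensionHandle unused;
      // Merge fails at graph-construction time unless dim 0 can equal dim 1.
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(m, 0), c->Dim(m, 1), &unused));
      c->set_output(0, m);
      return tensorflow::Status::OK();
    });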
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 984779f9fa..76572061a4 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -2490,8 +2490,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -2674,8 +2672,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -3989,8 +3985,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -4040,8 +4034,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -5730,8 +5722,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -5809,8 +5799,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -5888,8 +5876,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -8592,8 +8578,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -8664,8 +8648,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -8736,8 +8718,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -12374,29 +12354,6 @@ op {
}
}
op {
- name: "IdentityDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "IdentityN"
input_arg {
name: "input"
@@ -13268,6 +13225,36 @@ op {
is_stateful: true
}
op {
+ name: "IteratorFromStringHandleV2"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNext"
input_arg {
name: "iterator"
@@ -13328,6 +13315,34 @@ op {
is_stateful: true
}
op {
+ name: "IteratorV2"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "L2Loss"
input_arg {
name: "t"
@@ -15474,8 +15489,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -15553,8 +15566,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -15625,8 +15636,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -15768,8 +15777,6 @@ op {
list {
s: "NHWC"
s: "NCHW"
- s: "HWNC"
- s: "HWCN"
}
}
}
@@ -16989,6 +16996,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionWithOverlaps"
+ input_arg {
+ name: "overlaps"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "overlap_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
@@ -27370,6 +27404,17 @@ op {
}
}
op {
+ name: "SinkDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+}
+op {
name: "Size"
input_arg {
name: "input"
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 4a8110dc2b..aa35e8a116 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -1560,6 +1560,7 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
return Status::OK();
}
-REGISTER_FILE_SYSTEM("gs", RetryingGcsFileSystem);
-
} // namespace tensorflow
+
+// Register the GCS ("gs") file system.
+REGISTER_FILE_SYSTEM("gs", ::tensorflow::RetryingGcsFileSystem);
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index c17e4810d5..da1f66dc67 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -146,7 +146,6 @@ cc_library(
"@farmhash_archive//:farmhash",
"@fft2d",
"@highwayhash//:sip_hash",
- "@png_archive//:png",
],
)
@@ -161,7 +160,7 @@ cc_library(
"@farmhash_archive//:farmhash",
"@fft2d",
"@highwayhash//:sip_hash",
- "@png_archive//:png",
+ "@zlib_archive//:zlib",
],
)
@@ -187,6 +186,15 @@ cc_library(
)
cc_library(
+ name = "png",
+ copts = tf_copts(),
+ deps = [
+ "@png_archive//:png",
+ "@zlib_archive//:zlib",
+ ],
+)
+
+cc_library(
name = "protos_cc_impl",
copts = tf_copts(),
deps = [
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 9192f7ba10..e17ecc8c52 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -450,6 +450,6 @@ struct Register {
::tensorflow::register_file_system::Register<factory>(env, scheme)
#define REGISTER_FILE_SYSTEM(scheme, factory) \
- REGISTER_FILE_SYSTEM_ENV(Env::Default(), scheme, factory);
+ REGISTER_FILE_SYSTEM_ENV(::tensorflow::Env::Default(), scheme, factory);
#endif // TENSORFLOW_CORE_PLATFORM_ENV_H_
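Fully qualifying ::tensorflow::Env::Default() lets the macro be invoked from outside namespace tensorflow, which is exactly how the GCS registration above now uses it. A hedged sketch, where MyFileSystem stands in for a complete tensorflow::FileSystem subclass:

#include "tensorflow/core/platform/env.h"

namespace example {
class MyFileSystem;  // assumption: implements tensorflow::FileSystem
}  // namespace example

// Now legal at global scope, since the macro no longer depends on an
// enclosing tensorflow namespace (uncomment once MyFileSystem is defined):
// REGISTER_FILE_SYSTEM("myfs", ::example::MyFileSystem);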
diff --git a/tensorflow/core/platform/numa.h b/tensorflow/core/platform/numa.h
new file mode 100644
index 0000000000..b1f08e4c4c
--- /dev/null
+++ b/tensorflow/core/platform/numa.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_NUMA_H_
+#define TENSORFLOW_CORE_PLATFORM_NUMA_H_
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace port {
+
+// Returns true iff NUMA functions are supported.
+bool NUMAEnabled();
+
+// Returns the number of NUMA nodes present with respect to CPU operations.
+// Typically this will be the number of sockets where some RAM has greater
+// affinity with one socket than another.
+int NUMANumNodes();
+
+static const int kNUMANoAffinity = -1;
+
+// If possible sets affinity of the current thread to the specified NUMA node.
+// If node == kNUMANoAffinity removes affinity to any particular node.
+void NUMASetThreadNodeAffinity(int node);
+
+// Returns NUMA node affinity of the current thread, kNUMANoAffinity if none.
+int NUMAGetThreadNodeAffinity();
+
+// Like AlignedMalloc, but allocates memory with affinity to the specified NUMA
+// node.
+//
+// Notes:
+// 1. node must be >= 0 and < NUMANumNodes().
+// 2. minimum_alignment must be a factor of the system page size; the memory
+//    returned will be page-aligned.
+// 3. This function is likely significantly slower than AlignedMalloc
+//    and should not be used for lots of small allocations. It makes more
+//    sense as a backing allocator for BFCAllocator, PoolAllocator, or similar.
+void* NUMAMalloc(int node, size_t size, int minimum_alignment);
+
+// Memory allocated by NUMAMalloc must be freed via NUMAFree.
+void NUMAFree(void* ptr, size_t size);
+
+// Returns NUMA node affinity of memory address, kNUMANoAffinity if none.
+int NUMAGetMemAffinity(const void* ptr);
+
+} // namespace port
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_PLATFORM_NUMA_H_
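A usage sketch of the port::NUMA* API declared above; the function name and the 64-byte alignment are assumptions for illustration. With the posix stubs added below, NUMAEnabled() returns false and the loop body never runs:

#include <cstring>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/numa.h"

namespace tensorflow {

void TouchEveryNode(size_t bytes) {
  if (!port::NUMAEnabled()) return;
  for (int node = 0; node < port::NUMANumNodes(); ++node) {
    void* p = port::NUMAMalloc(node, bytes, /*minimum_alignment=*/64);
    if (p == nullptr) continue;
    port::NUMASetThreadNodeAffinity(node);  // pin the touching thread
    memset(p, 0, bytes);                    // first-touch pages on this node
    LOG(INFO) << "node " << node
              << " affinity: " << port::NUMAGetMemAffinity(p);
    port::NUMAFree(p, bytes);
  }
  port::NUMASetThreadNodeAffinity(port::kNUMANoAffinity);
}

}  // namespace tensorflow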
diff --git a/tensorflow/core/platform/numa_test.cc b/tensorflow/core/platform/numa_test.cc
new file mode 100644
index 0000000000..8b39ecd59c
--- /dev/null
+++ b/tensorflow/core/platform/numa_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/numa.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace internal {
+
+TEST(Numa, NumNodes) {
+ if (port::NUMAEnabled()) {
+ EXPECT_GE(port::NUMANumNodes(), 1);
+ }
+}
+
+TEST(Numa, Malloc) {
+ if (port::NUMAEnabled()) {
+ int num_nodes = port::NUMANumNodes();
+ for (int request_node = 0; request_node < num_nodes; ++request_node) {
+ void* ptr = port::NUMAMalloc(request_node, 8, 0);
+ EXPECT_NE(ptr, nullptr);
+ // Affinity cannot be tested until page is touched, so save a value.
+ *(reinterpret_cast<int*>(ptr)) = 0;
+ int affinity_node = port::NUMAGetMemAffinity(ptr);
+ EXPECT_EQ(affinity_node, request_node);
+ port::NUMAFree(ptr, 8);
+ }
+ }
+}
+
+TEST(Numa, SetNodeAffinity) {
+ // NOTE(tucker): This test is not reliable when executed under tap because
+  // the virtual machine may not have access to all of the available NUMA
+ // nodes. Not sure what to do about that.
+  EXPECT_EQ(port::kNUMANoAffinity, port::NUMAGetThreadNodeAffinity());
+ if (port::NUMAEnabled()) {
+ int num_nodes = port::NUMANumNodes();
+ for (int request_node = 0; request_node < num_nodes; ++request_node) {
+ port::NUMASetThreadNodeAffinity(request_node);
+ int affinity_node = port::NUMAGetThreadNodeAffinity();
+ EXPECT_EQ(affinity_node, request_node);
+ }
+ }
+}
+
+} // namespace internal
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 708f32ba80..1939cf72fb 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/types.h"
@@ -79,6 +80,19 @@ int NumHyperthreadsPerCore() {
return (ht_per_core > 0) ? ht_per_core : 1;
}
+bool NUMAEnabled() {
+ // Not yet implemented: coming soon.
+ return false;
+}
+
+int NUMANumNodes() { return 1; }
+
+void NUMASetThreadNodeAffinity(int node) {}
+
+int NUMAGetThreadNodeAffinity() {
+ return kNUMANoAffinity;
+}
+
void* AlignedMalloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__)
return memalign(minimum_alignment, size);
@@ -128,6 +142,16 @@ void Free(void* ptr) {
#endif
}
+void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
+ return AlignedMalloc(size, minimum_alignment);
+}
+
+void NUMAFree(void* ptr, size_t size) { Free(ptr); }
+
+int NUMAGetMemAffinity(const void* addr) {
+ return kNUMANoAffinity;
+}
+
void MallocExtension_ReleaseToSystem(std::size_t num_bytes) {
// No-op.
}
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc
index 02de7d1362..b0136b52f4 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.cc
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
+#include <fstream>
#include <limits>
#include <mutex>
@@ -67,22 +68,32 @@ static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr;
#if defined(__ANDROID__)
return GetCpuUtilsHelperSingletonInstance().CalculateCpuFrequency();
#elif defined(__linux__)
- double bogomips;
- FILE* fp = popen("grep '^bogomips' /proc/cpuinfo | head -1", "r");
- if (fp == nullptr) {
- return INVALID_FREQUENCY;
- }
- const int retval_of_bogomips = fscanf(fp, "bogomips : %lf", &bogomips);
- if (retval_of_bogomips <= 0) {
+ // Read the contents of /proc/cpuinfo.
+ std::ifstream cpuinfo("/proc/cpuinfo");
+ if (!cpuinfo) {
+ LOG(WARNING) << "Failed to open /proc/cpuinfo";
return INVALID_FREQUENCY;
}
- pclose(fp);
- const double freq_ghz = bogomips / 1000.0 / 2.0;
- if (retval_of_bogomips != 1 || freq_ghz < 0.01) {
- LOG(WARNING) << "Failed to get CPU frequency: " << freq_ghz << " Hz";
- return INVALID_FREQUENCY;
+ string line;
+ while (std::getline(cpuinfo, line)) {
+ double bogomips;
+ const int retval_of_bogomips =
+ sscanf(line.c_str(), "bogomips : %lf", &bogomips);
+ if (retval_of_bogomips > 0) {
+ const double freq_ghz = bogomips / 1000.0 / 2.0;
+      if (freq_ghz < 0.01) {
+        LOG(WARNING) << "Failed to get CPU frequency: " << freq_ghz << " GHz";
+ return INVALID_FREQUENCY;
+ }
+ const int64 freq_n =
+ static_cast<int64>(freq_ghz * 1000.0 * 1000.0 * 1000.0);
+ LOG(INFO) << "CPU Frequency: " << freq_n << " Hz";
+ return freq_n;
+ }
}
- return static_cast<int64>(freq_ghz * 1000.0 * 1000.0 * 1000.0);
+ LOG(WARNING) << "Failed to find bogomips in /proc/cpuinfo; cannot determine "
+ "CPU frequency";
+ return INVALID_FREQUENCY;
#elif defined(__APPLE__)
int64 freq_hz;
FILE* fp =
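A standalone sketch of the parsing strategy introduced above: scan /proc/cpuinfo line by line and take the first bogomips entry (function name is an assumption; on many x86 parts bogomips is roughly twice the clock in MHz):

#include <cstdio>
#include <fstream>
#include <string>

// Returns a rough clock estimate in GHz, or -1.0 if no bogomips line exists.
double BogomipsGHz() {
  std::ifstream cpuinfo("/proc/cpuinfo");
  std::string line;
  while (std::getline(cpuinfo, line)) {
    double bogomips;
    if (sscanf(line.c_str(), "bogomips : %lf", &bogomips) == 1) {
      return bogomips / 1000.0 / 2.0;
    }
  }
  return -1.0;
}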
diff --git a/tensorflow/core/platform/s3/s3_crypto.cc b/tensorflow/core/platform/s3/s3_crypto.cc
new file mode 100644
index 0000000000..d7062a59d2
--- /dev/null
+++ b/tensorflow/core/platform/s3/s3_crypto.cc
@@ -0,0 +1,113 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/platform/s3/s3_crypto.h"
+#include <openssl/hmac.h>
+#include <openssl/sha.h>
+
+#include <aws/core/utils/crypto/HashResult.h>
+#include <aws/s3/S3Client.h>
+
+namespace tensorflow {
+
+class S3Sha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
+ public:
+ S3Sha256HMACOpenSSLImpl() {}
+
+ virtual ~S3Sha256HMACOpenSSLImpl() = default;
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ const Aws::Utils::ByteBuffer& toSign,
+ const Aws::Utils::ByteBuffer& secret) override {
+ unsigned int length = SHA256_DIGEST_LENGTH;
+ Aws::Utils::ByteBuffer digest(length);
+ memset(digest.GetUnderlyingData(), 0, length);
+
+ HMAC_CTX ctx;
+ HMAC_CTX_init(&ctx);
+
+ HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
+ static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
+ HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
+ HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
+ HMAC_CTX_cleanup(&ctx);
+
+ return Aws::Utils::Crypto::HashResult(std::move(digest));
+ }
+};
+
+class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
+ public:
+ S3Sha256OpenSSLImpl() {}
+
+ virtual ~S3Sha256OpenSSLImpl() = default;
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ const Aws::String& str) override {
+ SHA256_CTX sha256;
+ SHA256_Init(&sha256);
+ SHA256_Update(&sha256, str.data(), str.size());
+
+ Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
+ SHA256_Final(hash.GetUnderlyingData(), &sha256);
+
+ return Aws::Utils::Crypto::HashResult(std::move(hash));
+ }
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ Aws::IStream& stream) override {
+ SHA256_CTX sha256;
+ SHA256_Init(&sha256);
+
+ auto currentPos = stream.tellg();
+ if (currentPos == std::streampos(std::streamoff(-1))) {
+ currentPos = 0;
+ stream.clear();
+ }
+
+ stream.seekg(0, stream.beg);
+
+ char streamBuffer
+ [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
+ while (stream.good()) {
+ stream.read(streamBuffer,
+ Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
+ auto bytesRead = stream.gcount();
+
+ if (bytesRead > 0) {
+ SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
+ }
+ }
+
+ stream.clear();
+ stream.seekg(currentPos, stream.beg);
+
+ Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
+ SHA256_Final(hash.GetUnderlyingData(), &sha256);
+
+ return Aws::Utils::Crypto::HashResult(std::move(hash));
+ }
+};
+
+std::shared_ptr<Aws::Utils::Crypto::Hash>
+S3SHA256Factory::CreateImplementation() const {
+ return Aws::MakeShared<S3Sha256OpenSSLImpl>(S3CryptoAllocationTag);
+}
+
+std::shared_ptr<Aws::Utils::Crypto::HMAC>
+S3SHA256HmacFactory::CreateImplementation() const {
+ return Aws::MakeShared<S3Sha256HMACOpenSSLImpl>(S3CryptoAllocationTag);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_crypto.h b/tensorflow/core/platform/s3/s3_crypto.h
new file mode 100644
index 0000000000..e376b8b0c0
--- /dev/null
+++ b/tensorflow/core/platform/s3/s3_crypto.h
@@ -0,0 +1,35 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <aws/core/Aws.h>
+#include <aws/core/utils/crypto/Factories.h>
+#include <aws/core/utils/crypto/HMAC.h>
+#include <aws/core/utils/crypto/Hash.h>
+
+namespace tensorflow {
+static const char* S3CryptoAllocationTag = "S3CryptoAllocation";
+
+class S3SHA256Factory : public Aws::Utils::Crypto::HashFactory {
+ public:
+ std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
+ const override;
+};
+
+class S3SHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
+ public:
+ std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
+ const override;
+};
+
+} // namespace tensorflow
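This diff does not show where the factories are installed; in the AWS C++ SDK the usual hook is SDKOptions::cryptoOptions, so the wiring sketched below is an assumption about intended use, not part of this change:

#include <aws/core/Aws.h>

#include "tensorflow/core/platform/s3/s3_crypto.h"

void InitAwsWithTfCrypto() {
  Aws::SDKOptions options;
  options.cryptoOptions.sha256Factory_create_fn = []() {
    return std::shared_ptr<Aws::Utils::Crypto::HashFactory>(
        new tensorflow::S3SHA256Factory());
  };
  options.cryptoOptions.sha256HMACFactory_create_fn = []() {
    return std::shared_ptr<Aws::Utils::Crypto::HMACFactory>(
        new tensorflow::S3SHA256HmacFactory());
  };
  Aws::InitAPI(options);  // must precede any S3 client construction
}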
diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc
index 0e08a04370..0f9e75bf9c 100644
--- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
+++ b/tensorflow/core/platform/vmodule_benchmark_test.cc
@@ -13,26 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <jni.h>
-#include <time.h>
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
-namespace tflite {
+namespace tensorflow {
-// Gets the elapsed wall-clock timespec.
-timespec getCurrentTime() {
- timespec time;
- clock_gettime(CLOCK_MONOTONIC, &time);
- return time;
+static void BM_DisabledVlog(int iters) {
+ for (int i = 0; i < iters; ++i) {
+ VLOG(1) << "Testing VLOG(1)!";
+ }
}
+BENCHMARK(BM_DisabledVlog);
-// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier
-// than 'start'.
-jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) {
- jlong result = stop->tv_sec - start->tv_sec;
- if (result < 0) return -1;
- result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec);
- if (result < 0) return -1;
- return result;
-}
-
-} // namespace tflite
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/vmodule_test.cc b/tensorflow/core/platform/vmodule_test.cc
new file mode 100644
index 0000000000..47b4b2e0e7
--- /dev/null
+++ b/tensorflow/core/platform/vmodule_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Test that popens a child process with the VLOG-ing environment variable set
+// for the logging framework, and observes VLOG_IS_ON and VLOG macro output.
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/test.h"
+
+#include <string.h>
+
+namespace tensorflow {
+namespace {
+
+int RealMain(const char* argv0, bool do_vlog) {
+ if (do_vlog) {
+#if !defined(PLATFORM_GOOGLE)
+ // Note, we only test this when !defined(PLATFORM_GOOGLE) because
+ // VmoduleActivated doesn't exist in that implementation.
+ //
+ // Also, we call this internal API to simulate what would happen if
+ // differently-named translation units attempted to VLOG, so we don't need
+ // to create dummy translation unit files.
+ bool ok = internal::LogMessage::VmoduleActivated("vmodule_test.cc", 7) &&
+ internal::LogMessage::VmoduleActivated("shoobadooba.h", 3);
+ if (!ok) {
+ fprintf(stderr, "vmodule activated levels not as expected.\n");
+ return EXIT_FAILURE;
+ }
+#endif
+
+ // Print info on which VLOG levels are activated.
+ fprintf(stderr, "VLOG_IS_ON(8)? %d\n", VLOG_IS_ON(8));
+ fprintf(stderr, "VLOG_IS_ON(7)? %d\n", VLOG_IS_ON(7));
+ fprintf(stderr, "VLOG_IS_ON(6)? %d\n", VLOG_IS_ON(6));
+ // Do some VLOG-ing.
+ VLOG(8) << "VLOG(8)";
+ VLOG(7) << "VLOG(7)";
+ VLOG(6) << "VLOG(6)";
+ LOG(INFO) << "INFO";
+ return EXIT_SUCCESS;
+ }
+
+ // Popen the child process.
+ std::string command = std::string(argv0);
+#if defined(PLATFORM_GOOGLE)
+ command = command + " do_vlog --vmodule=vmodule_test=7 --alsologtostderr";
+#else
+ command =
+ "TF_CPP_VMODULE=vmodule_test=7,shoobadooba=3 " + command + " do_vlog";
+#endif
+ command += " 2>&1";
+ fprintf(stderr, "Running: \"%s\"\n", command.c_str());
+ FILE* f = popen(command.c_str(), "r");
+ if (f == nullptr) {
+ fprintf(stderr, "Failed to popen child: %s\n", strerror(errno));
+ return EXIT_FAILURE;
+ }
+
+ // Read data from the child's stdout.
+ constexpr int kBufferSizeBytes = 4096;
+ char buffer[kBufferSizeBytes];
+ size_t result = fread(buffer, sizeof(buffer[0]), kBufferSizeBytes - 1, f);
+ if (result == 0) {
+ fprintf(stderr, "Failed to read from child stdout: %zu %s\n", result,
+ strerror(errno));
+ return EXIT_FAILURE;
+ }
+ buffer[result] = '\0';
+ int status = pclose(f);
+ if (status == -1) {
+ fprintf(stderr, "Failed to close popen child: %s\n", strerror(errno));
+ return EXIT_FAILURE;
+ }
+
+ // Check output is as expected.
+ const char kExpected[] =
+ "VLOG_IS_ON(8)? 0\nVLOG_IS_ON(7)? 1\nVLOG_IS_ON(6)? 1\n";
+ if (strstr(buffer, kExpected) == nullptr) {
+ fprintf(stderr, "error: unexpected output from child: \"%.*s\"\n",
+ kBufferSizeBytes, buffer);
+ return EXIT_FAILURE;
+ }
+ bool ok = strstr(buffer, "VLOG(7)\n") != nullptr &&
+ strstr(buffer, "VLOG(6)\n") != nullptr &&
+ strstr(buffer, "VLOG(8)\n") == nullptr;
+ if (!ok) {
+ fprintf(stderr, "error: VLOG output not as expected: \"%.*s\"\n",
+ kBufferSizeBytes, buffer);
+ return EXIT_FAILURE;
+ }
+
+ // Success!
+ return EXIT_SUCCESS;
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ bool do_vlog = argc >= 2 && strcmp(argv[1], "do_vlog") == 0;
+ return tensorflow::RealMain(argv[0], do_vlog);
+}
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
index 499900f965..811cf406b9 100644
--- a/tensorflow/core/protobuf/debug.proto
+++ b/tensorflow/core/protobuf/debug.proto
@@ -7,7 +7,7 @@ option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
-// EXPERIMENTAL. Option for watching a node.
+// Option for watching a node in TensorFlow Debugger (tfdbg).
message DebugTensorWatch {
// Name of the node to watch.
string node_name = 1;
@@ -51,7 +51,7 @@ message DebugTensorWatch {
bool tolerate_debug_op_creation_failures = 5;
}
-// EXPERIMENTAL. Options for initializing DebuggerState.
+// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg).
message DebugOptions {
// Debugging options
repeated DebugTensorWatch debug_tensor_watch_opts = 4;
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index d58c877cfd..cc8596ef3d 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -237,7 +237,7 @@ class Session {
/// If session creation succeeds, the new `Session` will be stored in
/// `*out_session`, the caller will take ownership of the returned
/// `*out_session`, and this function will return `OK()`. Otherwise, this
-/// function will return an error status.
+/// function will return an error status and set `*out_session` to nullptr.
Status NewSession(const SessionOptions& options, Session** out_session);
/// \brief Resets resource containers associated with a target.
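A minimal caller-side sketch of the contract documented above (MakeSession is an illustrative name): because *out_session is guaranteed to be nullptr on failure, wrapping it in a smart pointer unconditionally is safe.

#include <memory>

#include "tensorflow/core/public/session.h"

tensorflow::Status MakeSession(std::unique_ptr<tensorflow::Session>* out) {
  tensorflow::Session* raw = nullptr;
  tensorflow::Status s =
      tensorflow::NewSession(tensorflow::SessionOptions(), &raw);
  out->reset(raw);  // raw is nullptr whenever !s.ok()
  return s;
}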
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
index b773b33008..0782e7e1a8 100644
--- a/tensorflow/core/util/sparse/dim_comparator.h
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
-#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -49,11 +49,11 @@ class DimComparator {
DimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
const VarDimArray& shape)
: ix_(ix), order_(order), dims_(shape.size()) {
- CHECK_GT(order.size(), size_t{0}) << "Must order using at least one index";
- CHECK_LE(order.size(), shape.size()) << "Can only sort up to dims";
+ DCHECK_GT(order.size(), size_t{0}) << "Must order using at least one index";
+ DCHECK_LE(order.size(), shape.size()) << "Can only sort up to dims";
for (size_t d = 0; d < order.size(); ++d) {
- CHECK_GE(order[d], 0);
- CHECK_LT(order[d], shape.size());
+ DCHECK_GE(order[d], 0);
+ DCHECK_LT(order[d], shape.size());
}
}
@@ -97,7 +97,7 @@ class FixedDimComparator : DimComparator {
FixedDimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
const VarDimArray& shape)
: DimComparator(ix, order, shape) {
- CHECK_EQ(order.size(), ORDER_DIM);
+ DCHECK_EQ(order.size(), ORDER_DIM);
}
inline bool operator()(const int64 i, const int64 j) const {
bool value = false;
@@ -116,4 +116,4 @@ class FixedDimComparator : DimComparator {
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index fb70318078..3fa8cb6116 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
-#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -143,4 +143,4 @@ typename TTypes<T>::UnalignedVec Group::values() const {
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 258ee418c1..0f04b65f60 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
-#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#include <limits>
#include <numeric>
@@ -26,8 +26,10 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
@@ -41,32 +43,88 @@ class SparseTensor {
typedef typename gtl::ArraySlice<int64> VarDimArray;
typedef typename gtl::InlinedVector<int64, 8> ShapeArray;
+ static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+ const VarDimArray order, SparseTensor* result) {
+ if (ix.dtype() != DT_INT64) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ strings::StrCat("indices must be type int64 but got: ", ix.dtype()));
+ }
+ if (!TensorShapeUtils::IsVector(vals.shape())) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("vals must be a vec, but got: ",
+ vals.shape().DebugString()));
+ }
+ if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("indices and values rows (indexing "
+ "dimension) must match. (indices = ",
+ ix.shape().dim_size(0), ", values = ",
+ vals.shape().dim_size(0), ")"));
+ }
+ int dims;
+ TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
+ if (order.size() != dims) {
+ return Status(error::INVALID_ARGUMENT,
+ "Order length must be SparseTensor rank.");
+ }
+ if (shape.size() != dims) {
+ return Status(error::INVALID_ARGUMENT,
+ "Shape rank must be SparseTensor rank.");
+ }
+
+ *result = SparseTensor(ix, vals, shape, order);
+    return Status::OK();
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+ SparseTensor* result) {
+ return Create(ix, vals, TensorShapeToVector(shape),
+ UndefinedOrder(TensorShapeToVector(shape)), result);
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+ SparseTensor* result) {
+ return Create(ix, vals, shape, UndefinedOrder(shape), result);
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+ const VarDimArray order, SparseTensor* result) {
+ return Create(ix, vals, TensorShapeToVector(shape), order, result);
+ }
+
+ SparseTensor() : dims_(0) {}
+
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
vals_(vals),
shape_(shape.begin(), shape.end()),
order_(order.begin(), order.end()),
- dims_(GetDimsFromIx(ix)) {
- CHECK_EQ(ix.dtype(), DT_INT64)
+ dims_(UnsafeGetDimsFromIx(ix)) {
+ DCHECK_EQ(ix.dtype(), DT_INT64)
<< "indices must be type int64 but got: " << ix.dtype();
- CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ DCHECK(TensorShapeUtils::IsVector(vals.shape()))
<< "vals must be a vec, but got: " << vals.shape().DebugString();
- CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
<< "indices and values rows (indexing dimension) must match.";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
- CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
+ DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
}
SparseTensor(const SparseTensor& other)
@@ -81,6 +139,16 @@ class SparseTensor {
vals_ = other.vals_;
shape_ = other.shape_;
order_ = other.order_;
+ dims_ = other.dims_;
+ return *this;
+ }
+
+ SparseTensor& operator=(SparseTensor&& other) {
+ ix_ = std::move(other.ix_);
+ vals_ = std::move(other.vals_);
+ shape_ = std::move(other.shape_);
+ order_ = std::move(other.order_);
+ dims_ = std::move(other.dims_);
return *this;
}
@@ -126,11 +194,11 @@ class SparseTensor {
//
// See the README.md in this directory for more usage information.
GroupIterable group(const VarDimArray& group_ix) const {
- CHECK_LE(group_ix.size(), dims_);
+ DCHECK_LE(group_ix.size(), dims_);
for (std::size_t di = 0; di < group_ix.size(); ++di) {
- CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
- CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
- CHECK_EQ(group_ix[di], order_[di])
+ DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ DCHECK_EQ(group_ix[di], order_[di])
<< "Group dimension does not match sorted order";
}
return GroupIterable(ix_, vals_, dims_, group_ix);
@@ -166,9 +234,16 @@ class SparseTensor {
// isn't an integer multiple of split_dim, we add one extra dimension for
// each slice.
template <typename T>
+ static Status Split(const SparseTensor& tensor, const int split_dim,
+ const int num_split, std::vector<SparseTensor>* result);
+
+ // DEPRECATED: use the form of Split() that takes an output pointer and
+ // returns a status instead.
+ template <typename T>
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
- const int num_split);
+ const int num_split,
+ Status* status = nullptr);
// Slice() will slice the input SparseTensor into a SparseTensor based on
// specified start and size. Both start and size are 1-D array with each
@@ -189,9 +264,18 @@ class SparseTensor {
}
private:
- static int GetDimsFromIx(const Tensor& ix) {
- CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
- << "indices must be a matrix, but got: " << ix.shape().DebugString();
+ static Status GetDimsFromIx(const Tensor& ix, int* result) {
+ if (!TensorShapeUtils::IsMatrix(ix.shape())) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("indices must be a matrix, but got: ",
+ ix.shape().DebugString()));
+ }
+ *result = UnsafeGetDimsFromIx(ix);
+    return Status::OK();
+ }
+
+ static int UnsafeGetDimsFromIx(const Tensor& ix) {
+ DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
return ix.dim_size(1);
}
@@ -251,8 +335,8 @@ class SparseTensor {
// Helper for Split() that returns the slice index.
static inline int GetSliceIndex(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim / split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -265,8 +349,8 @@ class SparseTensor {
// Helper for Split() that returns the dimension in the slice.
static inline int GetDimensionInSlice(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim % split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -279,8 +363,8 @@ class SparseTensor {
// Helper for Split() that returns the shape given a slice index.
static inline int GetSliceShape(const int slice_index, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(slice_index, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(slice_index, 0);
if (residual == 0) return split_size;
if (slice_index < residual) {
return split_size + 1;
@@ -293,7 +377,7 @@ class SparseTensor {
Tensor vals_;
ShapeArray shape_;
ShapeArray order_;
- const int dims_;
+ int dims_;
};
// This operation updates the indices and values Tensor rows, so it is
@@ -301,9 +385,9 @@ class SparseTensor {
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "Reorder requested with the wrong datatype";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
auto ix_t = ix_.matrix<int64>();
auto vals_t = vals_.vec<T>();
@@ -360,13 +444,13 @@ void SparseTensor::Reorder(const VarDimArray& order) {
template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "ToDense requested with the wrong datatype";
- CHECK_EQ(out->shape().dims(), dims_)
+ DCHECK_EQ(out->shape().dims(), dims_)
<< "Incompatible dimensions between SparseTensor and output";
- CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
<< "Output must be type: " << DataTypeToEnum<T>::v()
<< " but got: " << out->dtype();
@@ -422,9 +506,9 @@ bool SparseTensor::ToDense(Tensor* out, bool initialize) {
template <typename T>
SparseTensor SparseTensor::Concat(
const gtl::ArraySlice<SparseTensor>& tensors) {
- CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
+ DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
const int dims = tensors[0].dims_;
- CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
auto order_0 = tensors[0].order();
const int primary_dim = order_0[0];
ShapeArray final_order(order_0.begin(), order_0.end());
@@ -434,17 +518,17 @@ SparseTensor SparseTensor::Concat(
bool fully_ordered = true;
for (const SparseTensor& st : tensors) {
- CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
- CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
<< "Concat requested with the wrong data type";
- CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
- CHECK_EQ(st.order()[0], primary_dim)
+ DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ DCHECK_EQ(st.order()[0], primary_dim)
<< "All SparseTensors' order[0] must match. This is the concat dim.";
if (st.order() != final_order) fully_ordered = false;
const VarDimArray& st_shape = st.shape();
for (int d = 0; d < dims - 1; ++d) {
const int cdim = (d < primary_dim) ? d : d + 1;
- CHECK_EQ(final_shape[cdim], st_shape[cdim])
+ DCHECK_EQ(final_shape[cdim], st_shape[cdim])
<< "All SparseTensors' shapes must match except on the concat dim. "
<< "Concat dim: " << primary_dim
<< ", mismatched shape at dim: " << cdim
@@ -494,7 +578,8 @@ SparseTensor SparseTensor::Concat(
template <typename T>
std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim,
- const int num_split) {
+ const int num_split,
+ Status* status /* = nullptr */) {
std::vector<Tensor> output_indices;
std::vector<Tensor> output_values;
std::vector<TensorShape> output_shapes;
@@ -514,12 +599,18 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim_size = input_tensor.shape()[split_dim];
const int split_size = split_dim_size / num_split;
- CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in "
- "the interval (0, "
- << split_dim_size << "]";
- CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in "
- "the interval [0, "
- << num_dim << ")";
+  if (!(num_split > 0 && num_split <= split_dim_size)) {
+    if (status != nullptr) {
+      *status = Status(error::INVALID_ARGUMENT,
+                       strings::StrCat("num_split must be in the interval (0, ",
+                                       split_dim_size, "]"));
+    }
+    return {};
+  }
+  if (!(split_dim >= 0 && split_dim < num_dim)) {
+    if (status != nullptr) {
+      *status = Status(error::INVALID_ARGUMENT,
+                       strings::StrCat("split_dim must be in the interval [0, ",
+                                       num_dim, ")"));
+    }
+    return {};
+  }
const int residual = split_dim_size % num_split;
for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
@@ -559,13 +650,28 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
std::vector<SparseTensor> output_tensors;
output_tensors.reserve(num_split);
for (int i = 0; i < num_split; ++i) {
- output_tensors.emplace_back(output_indices[i], output_values[i],
- output_shapes[i]);
+ SparseTensor tensor;
+ Status create_status =
+ Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
+ if (!create_status.ok() && status != nullptr) {
+ *status = create_status;
+ return {};
+ }
+ output_tensors.push_back(std::move(tensor));
}
return output_tensors;
}
template <typename T>
+Status SparseTensor::Split(const SparseTensor& input_tensor,
+ const int split_dim, const int num_split,
+ std::vector<SparseTensor>* result) {
+ Status status;
+ *result = Split<T>(input_tensor, split_dim, num_split, &status);
+ return status;
+}
+
+template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
const gtl::ArraySlice<int64>& start,
const gtl::ArraySlice<int64>& size) {
@@ -643,4 +749,4 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
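For orientation, here is a minimal sketch of how a caller exercises the status-returning API introduced in this hunk; the function and tensor names are placeholders, not part of the patch:

```c++
#include "tensorflow/core/util/sparse/sparse_tensor.h"

// Sketch only: Create() and the Status-returning Split<T>() overload replace
// the CHECK-crashing constructor and Split() call sites shown above.
tensorflow::Status SplitExample(const tensorflow::Tensor& indices,
                                const tensorflow::Tensor& values,
                                const tensorflow::TensorShape& shape) {
  tensorflow::sparse::SparseTensor st;
  TF_RETURN_IF_ERROR(
      tensorflow::sparse::SparseTensor::Create(indices, values, shape, &st));

  std::vector<tensorflow::sparse::SparseTensor> parts;
  // Bad split_dim/num_split values now surface as INVALID_ARGUMENT.
  TF_RETURN_IF_ERROR(
      tensorflow::sparse::SparseTensor::Split<tensorflow::int64>(
          st, /*split_dim=*/0, /*num_split=*/2, &parts));
  return tensorflow::Status::OK();
}
```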
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
index 85de032085..5578e42625 100644
--- a/tensorflow/core/util/sparse/sparse_tensor_test.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -94,9 +94,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesType) {
const int NDIM = 3;
Tensor ix(DT_INT32, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices must be type int64");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) {
@@ -104,9 +107,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM, 1}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices must be a matrix");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidValues) {
@@ -114,9 +120,12 @@ TEST(SparseTensorTest, SparseTensorInvalidValues) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N, 1}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "vals must be a vec");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidN) {
@@ -124,9 +133,12 @@ TEST(SparseTensorTest, SparseTensorInvalidN) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N - 1}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices and values rows .* must match");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidOrder) {
@@ -134,18 +146,24 @@ TEST(SparseTensorTest, SparseTensorInvalidOrder) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1}),
- "Order length must be SparseTensor rank");
+ EXPECT_EQ(
+ SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1}, &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidShape) {
int N = 5;
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10}), {0, 1, 2}),
- "Shape rank must be SparseTensor rank");
+ EXPECT_EQ(
+ SparseTensor::Create(ix, vals, TensorShape({10, 10}), {0, 1, 2}, &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorConstruction) {
@@ -169,7 +187,8 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("indices[2] = [2,0,0] is out of order",
@@ -210,7 +229,8 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) {
std::vector<int64> shape{10, 10, 10};
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(st.order(), order);
@@ -227,7 +247,8 @@ TEST(SparseTensorTest, SortingWorksCorrectly) {
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
TensorShape shape({1000, 1000, 1000, 1000});
- SparseTensor st(ix, vals, shape);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st));
auto ix_t = ix.matrix<int64>();
@@ -266,7 +287,8 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<string>(order);
Status st_indices_valid = st.IndicesValid();
@@ -302,7 +324,8 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
@@ -351,7 +374,8 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) {
TensorShape shape({4, 4, 5});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
st.ToDense<string>(&dense);
@@ -390,7 +414,8 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
TensorShape shape({4, 4, 5});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
st.ToDense<string>(&dense);
@@ -433,7 +458,8 @@ TEST(SparseTensorTest, SparseTensorGroup) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<int32>(order);
std::vector<std::vector<int64> > groups;
@@ -521,7 +547,8 @@ TEST(SparseTensorTest, Concat) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
TF_EXPECT_OK(st.IndicesValid());
@@ -551,7 +578,9 @@ TEST(SparseTensorTest, Concat) {
// Concat works if non-primary ix is out of order, but output order
// is not defined
- SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO
+ SparseTensor st_ooo;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1},
+ &st_ooo)); // non-primary ix OOO
SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
std::vector<int64> expected_ooo{-1, -1, -1};
EXPECT_EQ(conc_ooo.order(), expected_ooo);
@@ -584,9 +613,11 @@ TEST(SparseTensorTest, Split) {
vals.vec<int64>()(2) = 3;
vals.vec<int64>()(3) = 4;
- SparseTensor st(ids, vals, TensorShape({4, 3}));
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st));
- std::vector<SparseTensor> st_list = SparseTensor::Split<int64>(st, 0, 2);
+ std::vector<SparseTensor> st_list;
+ TF_ASSERT_OK(SparseTensor::Split<int64>(st, 0, 2, &st_list));
EXPECT_EQ(st_list.size(), 2);
auto expected_shape = gtl::InlinedVector<int64, 8>{2, 3};
@@ -633,7 +664,8 @@ TEST(SparseTensorTest, Slice) {
vals.vec<int64>()(2) = 3;
vals.vec<int64>()(3) = 4;
- SparseTensor st(ids, vals, TensorShape({4, 3}));
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st));
std::vector<int64> start(2, 0);
std::vector<int64> size(2);
@@ -662,7 +694,8 @@ TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) {
vals.scalar<int32>()() = 5;
TensorShape shape({});
- SparseTensor st(ix, vals, shape);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st));
Tensor dense(DT_INT32, TensorShape({}));
st.ToDense<int32>(&dense);
@@ -699,7 +732,8 @@ static void BM_SparseReorderFloat(int iters, int N32, int NDIM32) {
ix_t(i, d) = rnd.Rand64() % 1000;
}
}
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming();
st.Reorder<float>(reorder);
@@ -740,7 +774,8 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) {
ix_t(i, d) = rnd.Rand64() % 1000;
}
}
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming();
st.Reorder<string>(reorder);
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index 33ab87aa78..a5f7ecf0d1 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
string GetConvnetDataFormatAttrString() {
- return "data_format: { 'NHWC', 'NCHW', 'HWNC', 'HWCN' } = 'NHWC' ";
+ return "data_format: { 'NHWC', 'NCHW' } = 'NHWC' ";
}
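The helper above is consumed by convolution op registrations, so narrowing the returned string narrows what the attr accepts. A hedged sketch with a hypothetical op name:

```c++
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/util/tensor_format.h"

// Sketch only: an op that declares data_format via the helper now rejects
// 'HWNC' and 'HWCN', accepting just 'NHWC' (the default) and 'NCHW'.
REGISTER_OP("MyConv2D")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr(tensorflow::GetConvnetDataFormatAttrString());
```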
string GetConvnet3dDataFormatAttrString() {
diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md
index 9ef9674338..7028249e94 100644
--- a/tensorflow/docs_src/deploy/s3.md
+++ b/tensorflow/docs_src/deploy/s3.md
@@ -90,4 +90,4 @@ S3 was invented by Amazon, but the S3 API has spread in popularity and has sever
* [Amazon S3](https://aws.amazon.com/s3/)
* [Google Storage](https://cloud.google.com/storage/docs/interoperability)
-* [Minio](https://www.minio.io/kubernetes.html)(Standalone mode only)
+* [Minio](https://www.minio.io/kubernetes.html)
diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md
index 1ab0340ad9..d48340a777 100644
--- a/tensorflow/docs_src/extend/index.md
+++ b/tensorflow/docs_src/extend/index.md
@@ -17,7 +17,8 @@ TensorFlow:
Python is currently the only language supported by TensorFlow's API stability
promises. However, TensorFlow also provides functionality in C++, Go, Java and
-[JavaScript](https://js.tensorflow.org),
+[JavaScript](https://js.tensorflow.org) (including
+[Node.js](https://github.com/tensorflow/tfjs-node)),
plus community support for [Haskell](https://github.com/tensorflow/haskell) and
[Rust](https://github.com/tensorflow/rust). If you'd like to create or
develop TensorFlow features in a language other than these languages, read the
diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md
index d1d1f69766..abbf47910e 100644
--- a/tensorflow/docs_src/extend/new_data_formats.md
+++ b/tensorflow/docs_src/extend/new_data_formats.md
@@ -77,18 +77,24 @@ can be used as a starting point for your implementation:
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
-namespace tensorflow {
+namespace myproject {
namespace {
-class MyReaderDatasetOp : public DatasetOpKernel {
+using ::tensorflow::DT_STRING;
+using ::tensorflow::PartialTensorShape;
+using ::tensorflow::Status;
+
+class MyReaderDatasetOp : public tensorflow::DatasetOpKernel {
public:
- MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {
// Parse and validate any attrs that define the dataset using
// `ctx->GetAttr()`, and store them in member variables.
}
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ void MakeDataset(tensorflow::OpKernelContext* ctx,
+ tensorflow::DatasetBase** output) override {
// Parse and validate any input tensors that define the dataset using
// `ctx->input()` or the utility function
// `ParseScalarArgument<T>(ctx, &arg)`.
@@ -99,14 +105,14 @@ class MyReaderDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public tensorflow::GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
+ Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ std::unique_ptr<tensorflow::IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
+ return std::unique_ptr<tensorflow::IteratorBase>(new Iterator(
+ {this, tensorflow::strings::StrCat(prefix, "::MyReader")}));
}
// Record structure: Each record is represented by a scalar string tensor.
@@ -114,8 +120,8 @@ class MyReaderDatasetOp : public DatasetOpKernel {
// Dataset elements can have a fixed number of components of different
// types and shapes; replace the following two methods to customize this
// aspect of the dataset.
- const DataTypeVector& output_dtypes() const override {
- static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ const tensorflow::DataTypeVector& output_dtypes() const override {
+ static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
@@ -132,16 +138,16 @@ class MyReaderDatasetOp : public DatasetOpKernel {
// Implement this method if you want to be able to save and restore
// instances of this dataset (and any iterators over it).
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** output) const override {
+ tensorflow::Node** output) const override {
// Construct nodes to represent any of the input tensors from this
// object's member variables using `b->AddScalar()` and `b->AddVector()`.
- std::vector<Node*> input_tensors;
+ std::vector<tensorflow::Node*> input_tensors;
TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
return Status::OK();
}
private:
- class Iterator : public DatasetIterator<Dataset> {
+ class Iterator : public tensorflow::DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), i_(0) {}
@@ -158,15 +164,15 @@ class MyReaderDatasetOp : public DatasetOpKernel {
// return `Status::OK()`.
// 3. If an error occurs, return an error status using one of the helper
// functions from "tensorflow/core/lib/core/errors.h".
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
+ Status GetNextInternal(tensorflow::IteratorContext* ctx,
+ std::vector<tensorflow::Tensor>* out_tensors,
bool* end_of_sequence) override {
// NOTE: `GetNextInternal()` may be called concurrently, so it is
// recommended that you protect the iterator state with a mutex.
- mutex_lock l(mu_);
+ tensorflow::mutex_lock l(mu_);
if (i_ < 10) {
// Create a scalar string tensor and add it to the output.
- Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
+ tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
record_tensor.scalar<string>()() = "MyReader!";
out_tensors->emplace_back(std::move(record_tensor));
++i_;
@@ -183,20 +189,20 @@ class MyReaderDatasetOp : public DatasetOpKernel {
//
// Implement these two methods if you want to be able to save and restore
// instances of this iterator.
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ Status SaveInternal(tensorflow::IteratorStateWriter* writer) override {
+ tensorflow::mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
return Status::OK();
}
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ Status RestoreInternal(tensorflow::IteratorContext* ctx,
+ tensorflow::IteratorStateReader* reader) override {
+ tensorflow::mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
return Status::OK();
}
private:
- mutex mu_;
+ tensorflow::mutex mu_;
int64 i_ GUARDED_BY(mu_);
};
};
@@ -211,14 +217,14 @@ class MyReaderDatasetOp : public DatasetOpKernel {
REGISTER_OP("MyReaderDataset")
.Output("handle: variant")
.SetIsStateful()
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
// Register the kernel implementation for MyReaderDataset.
-REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU),
MyReaderDatasetOp);
} // namespace
-} // namespace tensorflow
+} // namespace myproject
```
The last step is to build the C++ code and add a Python wrapper. The easiest way
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index 003ca265fe..e98206eef9 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -421,7 +421,7 @@ class Model(tf.keras.Model):
super(Model, self).__init__()
self.W = tfe.Variable(5., name='weight')
self.B = tfe.Variable(10., name='bias')
- def predict(self, inputs):
+ def call(self, inputs):
return inputs * self.W + self.B
# A toy dataset of points around 3 * x + 2
@@ -432,7 +432,7 @@ training_outputs = training_inputs * 3 + 2 + noise
# The loss function to be optimized
def loss(model, inputs, targets):
- error = model.predict(inputs) - targets
+ error = model(inputs) - targets
return tf.reduce_mean(tf.square(error))
def grad(model, inputs, targets):
diff --git a/tensorflow/docs_src/guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md
index 1013ec910c..41080e050b 100644
--- a/tensorflow/docs_src/guide/feature_columns.md
+++ b/tensorflow/docs_src/guide/feature_columns.md
@@ -561,9 +561,9 @@ For more examples on feature columns, view the following:
* The @{$low_level_intro#feature_columns$Low Level Introduction} demonstrates how
to experiment directly with `feature_columns` using TensorFlow's low level APIs.
-* The @{$wide$wide} and @{$wide_and_deep$Wide & Deep} Tutorials solve a
- binary classification problem using `feature_columns` on a variety of input
- data types.
+* The [Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
+ solves a binary classification problem using `feature_columns` on a variety of
+ input data types.
To learn more about embeddings, see the following:
diff --git a/tensorflow/docs_src/guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md
index f581ae56da..a8876da5a5 100644
--- a/tensorflow/docs_src/guide/graph_viz.md
+++ b/tensorflow/docs_src/guide/graph_viz.md
@@ -248,7 +248,8 @@ The images below show the CIFAR-10 model with tensor shape information:
Often it is useful to collect runtime metadata for a run, such as total memory
usage, total compute time, and tensor shapes for nodes. The code example below
is a snippet from the train and test section of a modification of the
-@{$layers$simple MNIST tutorial}, in which we have recorded summaries and
+[Estimators MNIST tutorial](../tutorials/estimators/cnn.md), in which we have
+recorded summaries and
runtime statistics. See the
@{$summaries_and_tensorboard#serializing-the-data$Summaries Tutorial}
for details on how to record summaries.
diff --git a/tensorflow/docs_src/guide/index.md b/tensorflow/docs_src/guide/index.md
index eefdb9ceae..f78dfc9a89 100644
--- a/tensorflow/docs_src/guide/index.md
+++ b/tensorflow/docs_src/guide/index.md
@@ -16,15 +16,12 @@ works. The units are as follows:
## Estimators
-* @{$estimators} provides an introduction.
-* @{$premade_estimators}, introduces Estimators for machine learning.
-* @{$custom_estimators}, which demonstrates how to build and train models you
- design yourself.
-* @{$feature_columns}, which shows how an Estimator can handle a variety of input
- data types without changes to the model.
-* @{$datasets_for_estimators} describes using tf.data with estimators.
-* @{$checkpoints}, which explains how to save training progress and resume where
- you left off.
+* @{$estimators}, learn how to use Estimators for machine learning.
+* @{$premade_estimators}, the basics of premade Estimators.
+* @{$checkpoints}, save training progress and resume where you left off.
+* @{$feature_columns}, handle a variety of input data types without changes to the model.
+* @{$datasets_for_estimators}, use `tf.data` to input data.
+* @{$custom_estimators}, write your own Estimator.
## Accelerators
diff --git a/tensorflow/docs_src/guide/leftnav_files b/tensorflow/docs_src/guide/leftnav_files
index 357a2a1cb9..b3324278c1 100644
--- a/tensorflow/docs_src/guide/leftnav_files
+++ b/tensorflow/docs_src/guide/leftnav_files
@@ -8,10 +8,10 @@ datasets.md
### Estimators
estimators.md: Introduction to Estimators
premade_estimators.md
-custom_estimators.md
+checkpoints.md
feature_columns.md
datasets_for_estimators.md
-checkpoints.md
+custom_estimators.md
### Accelerators
using_gpu.md
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index f21c073a1b..541a55e184 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -511,6 +511,8 @@ on your system:
list of supported GPU cards.
* [GPU drivers](http://nvidia.com/drivers) that support your version of the CUDA
Toolkit.
+* NCCL 2.2 to use TensorFlow with multiple GPUs. For details, see [NVIDIA's
+ documentation](https://developer.nvidia.com/nccl).
* The `libcupti-dev` library is the NVIDIA CUDA Profile Tools Interface. This
library provides advanced profiling support. To install this library,
use the following command for CUDA Toolkit >= 8.0:
diff --git a/tensorflow/docs_src/javascript/index.md b/tensorflow/docs_src/javascript/index.md
deleted file mode 100644
index ad63eeb255..0000000000
--- a/tensorflow/docs_src/javascript/index.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# JavaScript
-
-You may develop TensorFlow programs in JavaScript, training and deploying
-models right in your browser. For details, see
-[js.tensorflow.org](https://js.tensorflow.org).
diff --git a/tensorflow/docs_src/javascript/leftnav_files b/tensorflow/docs_src/javascript/leftnav_files
deleted file mode 100644
index fc0ab8a543..0000000000
--- a/tensorflow/docs_src/javascript/leftnav_files
+++ /dev/null
@@ -1 +0,0 @@
-index.md
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 4c4f3f3934..68c427a316 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -2015,30 +2015,37 @@ two-operand version.
<b>`Sort(operand)`</b>
-Arguments | Type | Semantics
---------- | ------- | --------------------
-`operand` | `XlaOp` | The operand to sort.
-
-Sorts the elements in the operand in ascending order. The operand must be rank-1.
-If the operand's elements have floating point type, and the operand contains
-NaN elements, the order of elements in the output is implementation-defined.
+Arguments | Type | Semantics
+----------- | ------- | --------------------
+`operand` | `XlaOp` | The operand to sort.
+`dimension` | `int64` | The dimension along which to sort.
+
+Sorts the elements in the operand in ascending order along the provided
+dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0
+will sort each column independently, and a `dimension` value of 1 will sort each
+row independently. If the operand's elements have floating point type, and the
+operand contains NaN elements, the order of elements in the output is
+implementation-defined.
<b>`Sort(key, value)`</b>
Sorts both the key and the value operands. The keys are sorted as in the
single-operand version. The values are sorted according to the order of their
corresponding keys. For example, if the inputs are `keys = [3, 1]` and
-`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`.
+`values = [42, 50]`, then the output of the sort is the tuple
+`{[1, 3], [50, 42]}`.
+
The sort is not guaranteed to be stable, that is, if the keys array contains
duplicates, the order of their corresponding values may not be preserved.
-Arguments | Type | Semantics
---------- | ------- | -------------------
-`keys` | `XlaOp` | The sort keys.
-`values` | `XlaOp` | The values to sort.
+Arguments | Type | Semantics
+----------- | ------- | -------------------
+`keys` | `XlaOp` | The sort keys.
+`values` | `XlaOp` | The values to sort.
+`dimension` | `int64` | The dimension along which to sort.
-The `keys` and `values` operand must both be rank-1, and must have the same
-dimensions, but may have different element types.
+The `keys` and `values` must have the same dimensions, but may have different
+element types.
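Not XLA builder code, but a plain C++ sketch of the `dimension` semantics described above for the rank-2 case (column-wise for `dimension = 0`, row-wise for `dimension = 1`):

```c++
#include <algorithm>
#include <array>

// Reference behavior only: mirrors Sort's per-dimension contract on a
// rank-2 operand with R rows and C columns.
template <int R, int C>
void SortAlongDimension(std::array<std::array<int, C>, R>& m, int dimension) {
  if (dimension == 1) {
    // dimension = 1: sort each row independently.
    for (auto& row : m) std::sort(row.begin(), row.end());
  } else {
    // dimension = 0: sort each column independently.
    for (int col = 0; col < C; ++col) {
      std::array<int, R> column;
      for (int row = 0; row < R; ++row) column[row] = m[row][col];
      std::sort(column.begin(), column.end());
      for (int row = 0; row < R; ++row) m[row][col] = column[row];
    }
  }
}
```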
## Transpose
diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml
index 6fc8155669..07d561b8a2 100644
--- a/tensorflow/docs_src/tutorials/_index.yaml
+++ b/tensorflow/docs_src/tutorials/_index.yaml
@@ -170,15 +170,16 @@ landing_page:
<div class="devsite-landing-row-item-description-content">
<p>
Estimators can train large models on multiple machines in a
- production environment. Read the
- <a href="/guide/estimators">Estimators guide</a> for details.
+ production environment. TensorFlow provides a collection of
+ pre-made Estimators to implement common ML algorithms. See the
+ <a href="/guide/estimators">Estimators guide</a>.
</p>
<ol style="padding-left: 20px;">
- <li><a href="/tutorials/images/layers">Build a Convolutional Neural Network using Estimators</a></li>
+ <li><a href="/guide/premade_estimators">Premade Estimators guide</a></li>
+ <li><a href="https://github.com/tensorflow/models/tree/master/official/wide_deep" class="external">Wide and deep learning with Estimators</a></li>
+ <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees" class="external">Boosted trees</a></li>
<li><a href="/hub/tutorials/text_classification_with_tf_hub">How to build a simple text classifier with TF-Hub</a></li>
- <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees">Classifying Higgs boson processes</a></li>
- <li><a href="/tutorials/representation/wide_and_deep">Wide and deep learning using Estimators</a></li>
- <li><a href="/tutorials/representation/linear">Large-scale linear models</a></li>
+ <li><a href="/tutorials/estimators/cnn">Build a Convolutional Neural Network using Estimators</a></li>
</ol>
</div>
<div class="devsite-landing-row-item-buttons">
diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml
index d46d570a93..4db97e35fc 100644
--- a/tensorflow/docs_src/tutorials/_toc.yaml
+++ b/tensorflow/docs_src/tutorials/_toc.yaml
@@ -24,7 +24,7 @@ toc:
- title: Overview
path: /tutorials/eager/
- title: Eager execution
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_intro.ipynb
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
status: external
- title: Automatic differentiation
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -37,15 +37,27 @@ toc:
status: external
- title: "Custom training: walkthrough"
path: /tutorials/eager/custom_training_walkthrough
- - title: Neural machine translation
+ - title: Translation with attention
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
status: external
-- title: Images
+- title: ML at production scale
style: accordion
section:
+ - title: Wide and deep learning
+ path: https://github.com/tensorflow/models/tree/master/official/wide_deep
+ status: external
+ - title: Boosted trees
+ path: https://github.com/tensorflow/models/tree/master/official/boosted_trees
+ status: external
+ - title: Text classifier with TF-Hub
+ path: /hub/tutorials/text_classification_with_tf_hub
- title: Build a CNN using Estimators
- path: /tutorials/images/layers
+ path: /tutorials/estimators/cnn
+
+- title: Images
+ style: accordion
+ section:
- title: Image recognition
path: /tutorials/images/image_recognition
- title: Image retraining
@@ -69,10 +81,6 @@ toc:
- title: Data representation
style: accordion
section:
- - title: Linear models
- path: /tutorials/representation/wide
- - title: Wide and deep learning
- path: /tutorials/representation/wide_and_deep
- title: Vector representations of words
path: /tutorials/representation/word2vec
- title: Kernel methods
diff --git a/tensorflow/docs_src/tutorials/images/layers.md b/tensorflow/docs_src/tutorials/estimators/cnn.md
index 12a215b50c..12a215b50c 100644
--- a/tensorflow/docs_src/tutorials/images/layers.md
+++ b/tensorflow/docs_src/tutorials/estimators/cnn.md
diff --git a/tensorflow/docs_src/tutorials/images/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md
index 1590f15eb9..27963575f5 100644
--- a/tensorflow/docs_src/tutorials/images/deep_cnn.md
+++ b/tensorflow/docs_src/tutorials/images/deep_cnn.md
@@ -80,21 +80,21 @@ for details. It consists of 1,068,298 learnable parameters and requires about
## Code Organization
The code for this tutorial resides in
-[`models/tutorials/image/cifar10/`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/).
+[`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).
File | Purpose
--- | ---
-[`cifar10_input.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format.
-[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model.
-[`cifar10_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU.
-[`cifar10_multi_gpu_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs.
-[`cifar10_eval.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
+[`cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format.
+[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model.
+[`cifar10_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU.
+[`cifar10_multi_gpu_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs.
+[`cifar10_eval.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
## CIFAR-10 Model
The CIFAR-10 network is largely contained in
-[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py).
+[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py).
The complete training
graph contains roughly 765 operations. We find that we can make the code most
reusable by constructing the graph with the following modules:
diff --git a/tensorflow/docs_src/tutorials/images/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md
index 432d470d0c..d545de73df 100644
--- a/tensorflow/docs_src/tutorials/images/image_recognition.md
+++ b/tensorflow/docs_src/tutorials/images/image_recognition.md
@@ -449,7 +449,7 @@ covering them.
To find out more about implementing convolutional neural networks, you can jump
to the TensorFlow @{$deep_cnn$deep convolutional networks tutorial},
-or start a bit more gently with our @{$layers$MNIST starter tutorial}.
+or start a bit more gently with our [Estimator MNIST tutorial](../estimators/cnn.md).
Finally, if you want to get up to speed on research in this area, you can
read the recent work of all the papers referenced in this tutorial.
diff --git a/tensorflow/docs_src/tutorials/representation/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md
index 3f247ade26..1b418cf065 100644
--- a/tensorflow/docs_src/tutorials/representation/linear.md
+++ b/tensorflow/docs_src/tutorials/representation/linear.md
@@ -11,8 +11,9 @@ those tools. It explains:
deep learning to get the advantages of both.
Read this overview to decide whether the Estimator's linear model tools might
-be useful to you. Then do the @{$wide$Linear Models tutorial} to
-give it a try. This overview uses code samples from the tutorial, but the
+be useful to you. Then work through the
+[Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
+to give it a try. This overview uses code samples from the tutorial, but the
tutorial walks through the code in greater detail.
To understand this overview it will help to have some familiarity
@@ -176,7 +177,7 @@ the name of a `FeatureColumn`. Each key's value is a tensor containing the
values of that feature for all data instances. See
@{$premade_estimators#input_fn} for a
more comprehensive look at input functions, and `input_fn` in the
-[linear models tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py)
+[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
for an example implementation of an input function.
The input function is passed to the `train()` and `evaluate()` calls that
@@ -234,4 +235,5 @@ e = tf.estimator.DNNLinearCombinedClassifier(
dnn_feature_columns=deep_columns,
dnn_hidden_units=[100, 50])
```
-For more information, see the @{$wide_and_deep$Wide and Deep Learning tutorial}.
+For more information, see the
+[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
diff --git a/tensorflow/docs_src/tutorials/representation/wide.md b/tensorflow/docs_src/tutorials/representation/wide.md
deleted file mode 100644
index 27ce75a30d..0000000000
--- a/tensorflow/docs_src/tutorials/representation/wide.md
+++ /dev/null
@@ -1,461 +0,0 @@
-# TensorFlow Linear Model Tutorial
-
-In this tutorial, we will use the tf.estimator API in TensorFlow to solve a
-binary classification problem: Given census data about a person such as age,
-education, marital status, and occupation (the features), we will try to predict
-whether or not the person earns more than 50,000 dollars a year (the target
-label). We will train a **logistic regression** model, and given an individual's
-information our model will output a number between 0 and 1, which can be
-interpreted as the probability that the individual has an annual income of over
-50,000 dollars.
-
-## Setup
-
-To try the code for this tutorial:
-
-1. @{$install$Install TensorFlow} if you haven't already.
-
-2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/).
-
-3. Execute the data download script we provide to you:
-
- $ python data_download.py
-
-4. Execute the tutorial code with the following command to train the linear
-model described in this tutorial:
-
- $ python wide_deep.py --model_type=wide
-
-Read on to find out how this code builds its linear model.
-
-## Reading The Census Data
-
-The dataset we'll be using is the
-[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income).
-We have provided
-[data_download.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/data_download.py)
-which downloads the code and performs some additional cleanup.
-
-Since the task is a binary classification problem, we'll construct a label
-column named "label" whose value is 1 if the income is over 50K, and 0
-otherwise. For reference, see `input_fn` in
-[wide_deep.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-Next, let's take a look at the dataframe and see which columns we can use to
-predict the target label. The columns can be grouped into two types—categorical
-and continuous columns:
-
-* A column is called **categorical** if its value can only be one of the
- categories in a finite set. For example, the relationship status of a person
- (wife, husband, unmarried, etc.) or the education level (high school,
- college, etc.) are categorical columns.
-* A column is called **continuous** if its value can be any numerical value in
- a continuous range. For example, the capital gain of a person (e.g. $14,084)
- is a continuous column.
-
-Here's a list of columns available in the Census Income dataset:
-
-| Column Name | Type | Description |
-| -------------- | ----------- | --------------------------------- |
-| age | Continuous | The age of the individual |
-| workclass | Categorical | The type of employer the |
-: : : individual has (government, :
-: : : military, private, etc.). :
-| fnlwgt | Continuous | The number of people the census |
-: : : takers believe that observation :
-: : : represents (sample weight). Final :
-: : : weight will not be used. :
-| education | Categorical | The highest level of education |
-: : : achieved for that individual. :
-| education_num | Continuous | The highest level of education in |
-: : : numerical form. :
-| marital_status | Categorical | Marital status of the individual. |
-| occupation | Categorical | The occupation of the individual. |
-| relationship | Categorical | Wife, Own-child, Husband, |
-: : : Not-in-family, Other-relative, :
-: : : Unmarried. :
-| race | Categorical | Amer-Indian-Eskimo, Asian-Pac- |
-: : : Islander, Black, White, Other. :
-| gender | Categorical | Female, Male. |
-| capital_gain | Continuous | Capital gains recorded. |
-| capital_loss | Continuous | Capital Losses recorded. |
-| hours_per_week | Continuous | Hours worked per week. |
-| native_country | Categorical | Country of origin of the |
-: : : individual. :
-| income_bracket | Categorical | ">50K" or "<=50K", meaning |
-: : : whether the person makes more :
-: : : than $50,000 annually. :
-
-## Converting Data into Tensors
-
-When building a tf.estimator model, the input data is specified by means of an
-Input Builder function. This builder function will not be called until it is
-later passed to tf.estimator.Estimator methods such as `train` and `evaluate`.
-The purpose of this function is to construct the input data, which is
-represented in the form of @{tf.Tensor}s or @{tf.SparseTensor}s.
-In more detail, the input builder function returns the following as a pair:
-
-1. `features`: A dict from feature column names to `Tensors` or
- `SparseTensors`.
-2. `labels`: A `Tensor` containing the label column.
-
-The keys of the `features` will be used to construct columns in the next
-section. Because we want to call the `train` and `evaluate` methods with
-different data, we define a method that returns an input function based on the
-given data. Note that the returned input function will be called while
-constructing the TensorFlow graph, not while running the graph. What it is
-returning is a representation of the input data as the fundamental unit of
-TensorFlow computations, a `Tensor` (or `SparseTensor`).
-
-Each continuous column in the train or test data will be converted into a
-`Tensor`, which in general is a good format to represent dense data. For
-categorical data, we must represent the data as a `SparseTensor`. This data
-format is good for representing sparse data. Our `input_fn` uses the `tf.data`
-API, which makes it easy to apply transformations to our dataset:
-
-```python
-def input_fn(data_file, num_epochs, shuffle, batch_size):
- """Generate an input function for the Estimator."""
- assert tf.gfile.Exists(data_file), (
- '%s not found. Please make sure you have either run data_download.py or '
- 'set both arguments --train_data and --test_data.' % data_file)
-
- def parse_csv(value):
- print('Parsing', data_file)
- columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
- features = dict(zip(_CSV_COLUMNS, columns))
- labels = features.pop('income_bracket')
- return features, tf.equal(labels, '>50K')
-
- # Extract lines from input files using the Dataset API.
- dataset = tf.data.TextLineDataset(data_file)
-
- if shuffle:
- dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
-
- dataset = dataset.map(parse_csv, num_parallel_calls=5)
-
- # We call repeat after shuffling, rather than before, to prevent separate
- # epochs from blending together.
- dataset = dataset.repeat(num_epochs)
- dataset = dataset.batch(batch_size)
-
- iterator = dataset.make_one_shot_iterator()
- features, labels = iterator.get_next()
- return features, labels
-```
-
-## Selecting and Engineering Features for the Model
-
-Selecting and crafting the right set of feature columns is key to learning an
-effective model. A **feature column** can be either one of the raw columns in
-the original dataframe (let's call them **base feature columns**), or any new
-columns created based on some transformations defined over one or multiple base
-columns (let's call them **derived feature columns**). Basically, "feature
-column" is an abstract concept of any raw or derived variable that can be used
-to predict the target label.
-
-### Base Categorical Feature Columns
-
-To define a feature column for a categorical feature, we can create a
-`CategoricalColumn` using the tf.feature_column API. If you know the set of all
-possible feature values of a column and there are only a few of them, you can
-use `categorical_column_with_vocabulary_list`. Each key in the list will get
-assigned an auto-incremental ID starting from 0. For example, for the
-`relationship` column we can assign the feature string "Husband" to an integer
-ID of 0 and "Not-in-family" to 1, etc., by doing:
-
-```python
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-```
-
-What if we don't know the set of possible values in advance? Not a problem. We
-can use `categorical_column_with_hash_bucket` instead:
-
-```python
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-```
-
-What will happen is that each possible value in the feature column `occupation`
-will be hashed to an integer ID as we encounter them in training. See an example
-illustration below:
-
-ID | Feature
---- | -------------
-... |
-9 | `"Machine-op-inspct"`
-... |
-103 | `"Farming-fishing"`
-... |
-375 | `"Protective-serv"`
-... |
-
-No matter which way we choose to define a `SparseColumn`, each feature string
-will be mapped into an integer ID by looking up a fixed mapping or by hashing.
-Note that hashing collisions are possible, but may not significantly impact the
-model quality. Under the hood, the `LinearModel` class is responsible for
-managing the mapping and creating `tf.Variable` to store the model parameters
-(also known as model weights) for each feature ID. The model parameters will be
-learned through the model training process we'll go through later.
-
-We'll do the similar trick to define the other categorical features:
-
-```python
-education = tf.feature_column.categorical_column_with_vocabulary_list(
- 'education', [
- 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
- 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
- '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
-
-marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
- 'marital_status', [
- 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
- 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
-
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-
-workclass = tf.feature_column.categorical_column_with_vocabulary_list(
- 'workclass', [
- 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
- 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])
-
-# To show an example of hashing:
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-```
-
-### Base Continuous Feature Columns
-
-Similarly, we can define a `NumericColumn` for each continuous feature column
-that we want to use in the model:
-
-```python
-age = tf.feature_column.numeric_column('age')
-education_num = tf.feature_column.numeric_column('education_num')
-capital_gain = tf.feature_column.numeric_column('capital_gain')
-capital_loss = tf.feature_column.numeric_column('capital_loss')
-hours_per_week = tf.feature_column.numeric_column('hours_per_week')
-```
-
-### Making Continuous Features Categorical through Bucketization
-
-Sometimes the relationship between a continuous feature and the label is not
-linear. As a hypothetical example, a person's income may grow with age in the
-early stage of one's career, then the growth may slow at some point, and finally
-the income decreases after retirement. In this scenario, using the raw `age` as
-a real-valued feature column might not be a good choice because the model can
-only learn one of the three cases:
-
-1. Income always increases at some rate as age grows (positive correlation),
-1. Income always decreases at some rate as age grows (negative correlation), or
-1. Income stays the same no matter at what age (no correlation)
-
-If we want to learn the fine-grained correlation between income and each age
-group separately, we can leverage **bucketization**. Bucketization is a process
-of dividing the entire range of a continuous feature into a set of consecutive
-bins/buckets, and then converting the original numerical feature into a bucket
-ID (as a categorical feature) depending on which bucket that value falls into.
-So, we can define a `bucketized_column` over `age` as:
-
-```python
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-```
-
-where the `boundaries` is a list of bucket boundaries. In this case, there are
-10 boundaries, resulting in 11 age group buckets (from age 17 and below, 18-24,
-25-29, ..., to 65 and over).
-
-### Intersecting Multiple Columns with CrossedColumn
-
-Using each base feature column separately may not be enough to explain the data.
-For example, the correlation between education and the label (earning > 50,000
-dollars) may be different for different occupations. Therefore, if we only learn
-a single model weight for `education="Bachelors"` and `education="Masters"`, we
-won't be able to capture every single education-occupation combination (e.g.
-distinguishing between `education="Bachelors" AND occupation="Exec-managerial"`
-and `education="Bachelors" AND occupation="Craft-repair"`). To learn the
-differences between different feature combinations, we can add **crossed feature
-columns** to the model.
-
-```python
-education_x_occupation = tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000)
-```
-
-We can also create a `CrossedColumn` over more than two columns. Each
-constituent column can be either a base feature column that is categorical
-(`SparseColumn`), a bucketized real-valued feature column (`BucketizedColumn`),
-or even another `CrossColumn`. Here's an example:
-
-```python
-age_buckets_x_education_x_occupation = tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000)
-```
-
-## Defining The Logistic Regression Model
-
-After processing the input data and defining all the feature columns, we're now
-ready to put them all together and build a Logistic Regression model. In the
-previous section we've seen several types of base and derived feature columns,
-including:
-
-* `CategoricalColumn`
-* `NumericColumn`
-* `BucketizedColumn`
-* `CrossedColumn`
-
-All of these are subclasses of the abstract `FeatureColumn` class, and can be
-added to the `feature_columns` field of a model:
-
-```python
-base_columns = [
- education, marital_status, relationship, workclass, occupation,
- age_buckets,
-]
-crossed_columns = [
- tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
-]
-
-model_dir = tempfile.mkdtemp()
-model = tf.estimator.LinearClassifier(
- model_dir=model_dir, feature_columns=base_columns + crossed_columns)
-```
-
-The model also automatically learns a bias term, which controls the prediction
-one would make without observing any features (see the section "How Logistic
-Regression Works" for more explanations). The learned model files will be stored
-in `model_dir`.
-
-## Training and Evaluating Our Model
-
-After adding all the features to the model, now let's look at how to actually
-train the model. Training a model is just a single command using the
-tf.estimator API:
-
-```python
-model.train(input_fn=lambda: input_fn(train_data, num_epochs, True, batch_size))
-```
-
-After the model is trained, we can evaluate how good our model is at predicting
-the labels of the holdout data:
-
-```python
-results = model.evaluate(input_fn=lambda: input_fn(
- test_data, 1, False, batch_size))
-for key in sorted(results):
- print('%s: %s' % (key, results[key]))
-```
-
-The first line of the final output should be something like
-`accuracy: 0.83557522`, which means the accuracy is 83.6%. Feel free to try more
-features and transformations and see if you can do even better!
-
-After the model is evaluated, we can use the model to predict whether an individual has an annual income of over
-50,000 dollars given an individual's information input.
-```python
- pred_iter = model.predict(input_fn=lambda: input_fn(FLAGS.test_data, 1, False, 1))
- for pred in pred_iter:
- print(pred['classes'])
-```
-
-The model prediction output would be like `[b'1']` or `[b'0']` which means whether corresponding individual has an annual income of over 50,000 dollars or not.
-
-If you'd like to see a working end-to-end example, you can download our
-[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py)
-and set the `model_type` flag to `wide`.
-
-## Adding Regularization to Prevent Overfitting
-
-Regularization is a technique used to avoid **overfitting**. Overfitting happens
-when your model does well on the data it is trained on, but worse on test data
-that the model has not seen before, such as live traffic. Overfitting generally
-occurs when a model is excessively complex, such as having too many parameters
-relative to the number of observed training data. Regularization allows for you
-to control your model's complexity and makes the model more generalizable to
-unseen data.
-
-In the Linear Model library, you can add L1 and L2 regularizations to the model
-as:
-
-```
-model = tf.estimator.LinearClassifier(
- model_dir=model_dir, feature_columns=base_columns + crossed_columns,
- optimizer=tf.train.FtrlOptimizer(
- learning_rate=0.1,
- l1_regularization_strength=1.0,
- l2_regularization_strength=1.0))
-```
-
-One important difference between L1 and L2 regularization is that L1
-regularization tends to make model weights stay at zero, creating sparser
-models, whereas L2 regularization also tries to make the model weights closer to
-zero but not necessarily zero. Therefore, if you increase the strength of L1
-regularization, you will have a smaller model size because many of the model
-weights will be zero. This is often desirable when the feature space is very
-large but sparse, and when there are resource constraints that prevent you from
-serving a model that is too large.
-
-In practice, you should try various combinations of L1, L2 regularization
-strengths and find the best parameters that best control overfitting and give
-you a desirable model size.
-
-## How Logistic Regression Works
-
-Finally, let's take a minute to talk about what the Logistic Regression model
-actually looks like in case you're not already familiar with it. We'll denote
-the label as \\(Y\\), and the set of observed features as a feature vector
-\\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). We define \\(Y=1\\) if an individual
-earned > 50,000 dollars and \\(Y=0\\) otherwise. In Logistic Regression, the
-probability of the label being positive (\\(Y=1\\)) given the features
-\\(\mathbf{x}\\) is given as:
-
-$$ P(Y=1|\mathbf{x}) = \frac{1}{1+\exp(-(\mathbf{w}^T\mathbf{x}+b))}$$
-
-where \\(\mathbf{w}=[w_1, w_2, ..., w_d]\\) are the model weights for the
-features \\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). \\(b\\) is a constant that is
-often called the **bias** of the model. The equation consists of two parts—A
-linear model and a logistic function:
-
-* **Linear Model**: First, we can see that \\(\mathbf{w}^T\mathbf{x}+b = b +
- w_1x_1 + ... +w_dx_d\\) is a linear model where the output is a linear
- function of the input features \\(\mathbf{x}\\). The bias \\(b\\) is the
- prediction one would make without observing any features. The model weight
- \\(w_i\\) reflects how the feature \\(x_i\\) is correlated with the positive
- label. If \\(x_i\\) is positively correlated with the positive label, the
- weight \\(w_i\\) increases, and the probability \\(P(Y=1|\mathbf{x})\\) will
- be closer to 1. On the other hand, if \\(x_i\\) is negatively correlated
- with the positive label, then the weight \\(w_i\\) decreases and the
- probability \\(P(Y=1|\mathbf{x})\\) will be closer to 0.
-
-* **Logistic Function**: Second, we can see that there's a logistic function
- (also known as the sigmoid function) \\(S(t) = 1/(1+\exp(-t))\\) being
- applied to the linear model. The logistic function is used to convert the
- output of the linear model \\(\mathbf{w}^T\mathbf{x}+b\\) from any real
- number into the range of \\([0, 1]\\), which can be interpreted as a
- probability.
-
-Model training is an optimization problem: The goal is to find a set of model
-weights (i.e. model parameters) to minimize a **loss function** defined over the
-training data, such as logistic loss for Logistic Regression models. The loss
-function measures the discrepancy between the ground-truth label and the model's
-prediction. If the prediction is very close to the ground-truth label, the loss
-value will be low; if the prediction is very far from the label, then the loss
-value would be high.
-
-## Learn Deeper
-
-If you're interested in learning more, check out our
-@{$wide_and_deep$Wide & Deep Learning Tutorial} where we'll show you how to
-combine the strengths of linear models and deep neural networks by jointly
-training them using the tf.estimator API.
diff --git a/tensorflow/docs_src/tutorials/representation/wide_and_deep.md b/tensorflow/docs_src/tutorials/representation/wide_and_deep.md
deleted file mode 100644
index 44677a810b..0000000000
--- a/tensorflow/docs_src/tutorials/representation/wide_and_deep.md
+++ /dev/null
@@ -1,243 +0,0 @@
-# TensorFlow Wide & Deep Learning Tutorial
-
-In the previous @{$wide$TensorFlow Linear Model Tutorial}, we trained a logistic
-regression model to predict the probability that the individual has an annual
-income of over 50,000 dollars using the
-[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income).
-TensorFlow is great for training deep neural networks too, and you might be
-thinking which one you should choose—well, why not both? Would it be possible to
-combine the strengths of both in one model?
-
-In this tutorial, we'll introduce how to use the tf.estimator API to jointly
-train a wide linear model and a deep feed-forward neural network. This approach
-combines the strengths of memorization and generalization. It's useful for
-generic large-scale regression and classification problems with sparse input
-features (e.g., categorical features with a large number of possible feature
-values). If you're interested in learning more about how Wide & Deep Learning
-works, please check out our [research paper](https://arxiv.org/abs/1606.07792).
-
-![Wide & Deep Spectrum of Models](https://www.tensorflow.org/images/wide_n_deep.svg "Wide & Deep")
-
-The figure above shows a comparison of a wide model (logistic regression with
-sparse features and transformations), a deep model (feed-forward neural network
-with an embedding layer and several hidden layers), and a Wide & Deep model
-(joint training of both). At a high level, there are only 3 steps to configure a
-wide, deep, or Wide & Deep model using the tf.estimator API:
-
-1. Select features for the wide part: Choose the sparse base columns and
- crossed columns you want to use.
-1. Select features for the deep part: Choose the continuous columns, the
- embedding dimension for each categorical column, and the hidden layer sizes.
-1. Put them all together in a Wide & Deep model
- (`DNNLinearCombinedClassifier`).
-
-And that's it! Let's go through a simple example.
-
-## Setup
-
-To try the code for this tutorial:
-
-1. @{$install$Install TensorFlow} if you haven't already.
-
-2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/).
-
-3. Execute the data download script we provide:
-
- $ python data_download.py
-
-4. Execute the tutorial code with the following command to train the wide and
-deep model described in this tutorial:
-
- $ python wide_deep.py
-
-Read on to find out how this code builds its model.
-
-
-## Define Base Feature Columns
-
-First, let's define the base categorical and continuous feature columns that
-we'll use. These base columns will be the building blocks used by both the wide
-part and the deep part of the model.
-
-```python
-import tensorflow as tf
-
-# Continuous columns
-age = tf.feature_column.numeric_column('age')
-education_num = tf.feature_column.numeric_column('education_num')
-capital_gain = tf.feature_column.numeric_column('capital_gain')
-capital_loss = tf.feature_column.numeric_column('capital_loss')
-hours_per_week = tf.feature_column.numeric_column('hours_per_week')
-
-education = tf.feature_column.categorical_column_with_vocabulary_list(
- 'education', [
- 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
- 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
- '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
-
-marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
- 'marital_status', [
- 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
- 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
-
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-
-workclass = tf.feature_column.categorical_column_with_vocabulary_list(
- 'workclass', [
- 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
- 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])
-
-# To show an example of hashing:
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-
-# Transformations.
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-```
-
-## The Wide Model: Linear Model with Crossed Feature Columns
-
-The wide model is a linear model with a wide set of sparse and crossed feature
-columns:
-
-```python
-base_columns = [
- education, marital_status, relationship, workclass, occupation,
- age_buckets,
-]
-
-crossed_columns = [
- tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
-]
-```
-
-You can also see the @{$wide$TensorFlow Linear Model Tutorial} for more details.
-
-Wide models with crossed feature columns can memorize sparse interactions
-between features effectively. That being said, one limitation of crossed feature
-columns is that they do not generalize to feature combinations that have not
-appeared in the training data. Let's add a deep model with embeddings to fix
-that.
-
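Conceptually, a crossed column hashes the concatenation of its input values
into one of `hash_bucket_size` sparse IDs, and the model learns one weight per
bucket. A simplified sketch of that idea (the real `tf.feature_column`
implementation uses a different, deterministic hash):

```python
def cross_bucket(values, hash_bucket_size=1000):
  """Maps a tuple of feature values to a sparse bucket ID."""
  # NOTE: Python's str hash is randomized per process; fine for illustration.
  key = '_X_'.join(values)
  return hash(key) % hash_bucket_size

print(cross_bucket(('Bachelors', 'Exec-managerial')))  # some ID in [0, 1000)
print(cross_bucket(('Masters', 'Adm-clerical')))       # usually a different ID
```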
-## The Deep Model: Neural Network with Embeddings
-
-The deep model is a feed-forward neural network, as shown in the previous
-figure. Each of the sparse, high-dimensional categorical features is first
-converted into a low-dimensional, dense real-valued vector, often referred to
-as an embedding vector. These low-dimensional dense embedding vectors are
-concatenated with the continuous features, and then fed into the hidden layers
-of a neural network in the forward pass. The embedding values are initialized
-randomly, and are trained along with all other model parameters to minimize the
-training loss. If you're interested in learning more about embeddings, check out
-the TensorFlow tutorial on @{$word2vec$Vector Representations of Words} or
-[Word embedding](https://en.wikipedia.org/wiki/Word_embedding) on Wikipedia.
-
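Mechanically, an embedding column is just a trainable lookup table: row
\\(i\\) of the table is the dense vector for categorical ID \\(i\\). A rough
sketch with assumed shapes (not the tutorial's code):

```python
import numpy as np

vocab_size, embedding_dim = 1000, 8
# Trainable table; in TensorFlow this would be a variable updated by training.
embedding_table = np.random.randn(vocab_size, embedding_dim)

occupation_id = 42                                 # sparse categorical ID
embedding_vector = embedding_table[occupation_id]  # dense 8-dim vector
print(embedding_vector.shape)                      # (8,)
```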
-Another way to represent categorical columns to feed into a neural network is
-via a one-hot or multi-hot representation. This is often appropriate for
-categorical columns with only a few possible values. As an example of a one-hot
-representation, for the relationship column, `"Husband"` can be represented as
-`[1, 0, 0, 0, 0, 0]`, `"Not-in-family"` as `[0, 1, 0, 0, 0, 0]`, and so on. This
-is a fixed representation, whereas embeddings are more flexible and calculated
-at training time.
-
-We'll configure the embeddings for the categorical columns using
-`embedding_column`, and concatenate them with the continuous columns.
-We also use `indicator_column` to create multi-hot representations of some
-categorical columns.
-
-```python
-deep_columns = [
- age,
- education_num,
- capital_gain,
- capital_loss,
- hours_per_week,
- tf.feature_column.indicator_column(workclass),
- tf.feature_column.indicator_column(education),
- tf.feature_column.indicator_column(marital_status),
- tf.feature_column.indicator_column(relationship),
- # To show an example of embedding
- tf.feature_column.embedding_column(occupation, dimension=8),
-]
-```
-
-The higher the `dimension` of the embedding is, the more degrees of freedom the
-model will have to learn the representations of the features. For simplicity, we
-set the dimension to 8 for all feature columns here. Empirically, a more
-informed decision for the number of dimensions is to start with a value on the
-order of \\(\log_2(n)\\) or \\(k\sqrt[4]{n}\\), where \\(n\\) is the number of
-unique values in the feature column and \\(k\\) is a small constant (usually
-smaller than 10).
-
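For instance, plugging an arbitrary vocabulary size into the two heuristics
(an illustrative calculation, not tutorial code):

```python
import math

n = 10000  # hypothetical number of unique values in the column
k = 5      # small constant, usually smaller than 10

print(math.log2(n))   # ~13.3 -> try a dimension around 13
print(k * n ** 0.25)  # 50.0  -> or around 50, depending on k
```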
-Through dense embeddings, deep models can generalize better and make predictions
-on feature pairs that were previously unseen in the training data. However, it
-is difficult to learn effective low-dimensional representations for feature
-columns when the underlying interaction matrix between two feature columns is
-sparse and high-rank. In such cases, the interaction between most feature pairs
-should be zero except for a few, but dense embeddings will lead to nonzero
-predictions for all feature pairs, and thus can over-generalize. On the other
-hand, linear models with crossed features can memorize these “exception rules”
-effectively with fewer model parameters.
-
-Now, let's see how to jointly train wide and deep models and allow them to
-complement each other’s strengths and weaknesses.
-
-## Combining Wide and Deep Models into One
-
-The wide model and the deep model are combined by summing up their final output
-log odds as the prediction, then feeding the prediction into a logistic loss
-function. All of the graph definition and variable allocation has already been
-handled for you under the hood, so you simply need to create a
-`DNNLinearCombinedClassifier`:
-
-```python
-model = tf.estimator.DNNLinearCombinedClassifier(
- model_dir='/tmp/census_model',
- linear_feature_columns=base_columns + crossed_columns,
- dnn_feature_columns=deep_columns,
- dnn_hidden_units=[100, 50])
-```
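Under the hood, the combined prediction is roughly a sigmoid over the summed
logits of the two parts. A conceptual sketch with made-up logit values (the
estimator handles this internally):

```python
import math

wide_logits = 1.2   # hypothetical output of the linear part
deep_logits = -0.4  # hypothetical output of the DNN part

combined = wide_logits + deep_logits   # summed log odds: 0.8
p = 1.0 / (1.0 + math.exp(-combined))  # sigmoid -> probability
print(p)  # ~0.69
```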
-
-## Training and Evaluating the Model
-
-Before we train the model, let's read in the Census dataset as we did in the
-@{$wide$TensorFlow Linear Model Tutorial}. See `data_download.py` as well as
-`input_fn` within
-[`wide_deep.py`](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-After reading in the data, you can train and evaluate the model:
-
-```python
-# Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
-for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
- model.train(input_fn=lambda: input_fn(
- FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
-
- results = model.evaluate(input_fn=lambda: input_fn(
- FLAGS.test_data, 1, False, FLAGS.batch_size))
-
- # Display evaluation metrics
- print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
- print('-' * 30)
-
- for key in sorted(results):
- print('%s: %s' % (key, results[key]))
-```
-
-The final output accuracy should be somewhere around 85.5%. If you'd like to
-see a working end-to-end example, you can download our
-[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-Note that this tutorial is just a quick example on a small dataset to get you
-familiar with the API. Wide & Deep Learning will be even more powerful if you
-try it on a large dataset with many sparse feature columns that have a large
-number of possible feature values. Again, feel free to take a look at our
-[research paper](https://arxiv.org/abs/1606.07792) for more ideas about how to
-apply Wide & Deep Learning in real-world large-scale machine learning problems.
diff --git a/tensorflow/docs_src/tutorials/representation/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md
index 3fe7352bd2..0a1c41c84a 100644
--- a/tensorflow/docs_src/tutorials/representation/word2vec.md
+++ b/tensorflow/docs_src/tutorials/representation/word2vec.md
@@ -23,7 +23,7 @@ straight in, feel free to look at the minimalistic implementation in
This basic example contains the code needed to download some data, train on it a
bit and visualize the result. Once you get comfortable with reading and running
the basic version, you can graduate to
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py)
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py)
which is a more serious implementation that showcases some more advanced
TensorFlow principles about how to efficiently use threads to move data into a
text model, how to checkpoint during training, etc.
@@ -341,7 +341,7 @@ t-SNE.
Et voilà! As expected, words that are similar end up clustering near each
other. For a more heavyweight implementation of word2vec that showcases more of
the advanced features of TensorFlow, see the implementation in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
## Evaluating Embeddings: Analogical Reasoning
@@ -357,7 +357,7 @@ Download the dataset for this task from
To see how we do this evaluation, have a look at the `build_eval_graph()` and
`eval()` functions in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
The choice of hyperparameters can strongly influence the accuracy on this task.
To achieve state-of-the-art performance on this task requires training over a
@@ -385,13 +385,13 @@ your model is seriously bottlenecked on input data, you may want to implement a
custom data reader for your problem, as described in
@{$new_data_formats$New Data Formats}. For the case of Skip-Gram
modeling, we've actually already done this for you as an example in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
If your model is no longer I/O bound but you still want more performance, you
can take things further by writing your own TensorFlow Ops, as described in
@{$adding_an_op$Adding a New Op}. Again, we've provided an
example of this for the Skip-Gram case
-[models/tutorials/embedding/word2vec_optimized.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec_optimized.py).
+[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py).
Feel free to benchmark these against each other to measure performance
improvements at each stage.
diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD
index 13bca34a86..7a44e2ee4f 100644
--- a/tensorflow/examples/speech_commands/BUILD
+++ b/tensorflow/examples/speech_commands/BUILD
@@ -56,6 +56,7 @@ tf_py_test(
srcs = ["input_data_test.py"],
additional_deps = [
":input_data",
+ ":models",
"//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py
index c8671d9c41..7657b23c60 100644
--- a/tensorflow/examples/speech_commands/freeze.py
+++ b/tensorflow/examples/speech_commands/freeze.py
@@ -54,7 +54,7 @@ FLAGS = None
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms, window_size_ms, window_stride_ms,
- dct_coefficient_count, model_architecture):
+ feature_bin_count, model_architecture, preprocess):
"""Creates an audio model with the nodes needed for inference.
Uses the supplied arguments to create a model, and inserts the input and
@@ -67,14 +67,19 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
clip_stride_ms: How often to run recognition. Useful for models with cache.
window_size_ms: Time slice duration to estimate frequencies from.
window_stride_ms: How far apart time slices should be.
- dct_coefficient_count: Number of frequency bands to analyze.
+ feature_bin_count: Number of frequency bands to analyze.
model_architecture: Name of the kind of model to generate.
+ preprocess: How the spectrogram is processed to produce features, for
+ example 'mfcc' or 'average'.
+
+ Raises:
+ Exception: If the preprocessing mode isn't recognized.
"""
words_list = input_data.prepare_words_list(wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), sample_rate, clip_duration_ms, window_size_ms,
- window_stride_ms, dct_coefficient_count)
+ window_stride_ms, feature_bin_count, preprocess)
runtime_settings = {'clip_stride_ms': clip_stride_ms}
wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
@@ -88,15 +93,25 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
window_size=model_settings['window_size_samples'],
stride=model_settings['window_stride_samples'],
magnitude_squared=True)
- fingerprint_input = contrib_audio.mfcc(
- spectrogram,
- decoded_sample_data.sample_rate,
- dct_coefficient_count=dct_coefficient_count)
- fingerprint_frequency_size = model_settings['dct_coefficient_count']
- fingerprint_time_size = model_settings['spectrogram_length']
- reshaped_input = tf.reshape(fingerprint_input, [
- -1, fingerprint_time_size * fingerprint_frequency_size
- ])
+
+ if preprocess == 'average':
+ fingerprint_input = tf.nn.pool(
+ tf.expand_dims(spectrogram, -1),
+ window_shape=[1, model_settings['average_window_width']],
+ strides=[1, model_settings['average_window_width']],
+ pooling_type='AVG',
+ padding='SAME')
+ elif preprocess == 'mfcc':
+ fingerprint_input = contrib_audio.mfcc(
+ spectrogram,
+ sample_rate,
+ dct_coefficient_count=model_settings['fingerprint_width'])
+ else:
+ raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (preprocess))
+
+ fingerprint_size = model_settings['fingerprint_size']
+ reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])
logits = models.create_model(
reshaped_input, model_settings, model_architecture, is_training=False,
@@ -110,10 +125,12 @@ def main(_):
# Create the model and load its weights.
sess = tf.InteractiveSession()
- create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
- FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
- FLAGS.window_size_ms, FLAGS.window_stride_ms,
- FLAGS.dct_coefficient_count, FLAGS.model_architecture)
+ create_inference_graph(
+ FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
+ FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
+ FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
+ if FLAGS.quantize:
+ tf.contrib.quantize.create_training_graph(quant_delay=0)
models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
# Turn all the variables into inline constants inside the graph and save it.
@@ -155,10 +172,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+      help='How many bins to use for the audio fingerprint',
+ )
parser.add_argument(
'--start_checkpoint',
type=str,
@@ -176,5 +194,15 @@ if __name__ == '__main__':
help='Words to use (others will be added to an unknown label)',)
parser.add_argument(
'--output_file', type=str, help='Where to save the frozen graph.')
+ parser.add_argument(
+ '--quantize',
+      # argparse's type=bool treats any non-empty string as True, so expose
+      # this as a store_true flag instead.
+      action='store_true',
+      default=False,
+ help='Whether to train the model for eight-bit deployment')
+ parser.add_argument(
+ '--preprocess',
+ type=str,
+ default='mfcc',
+ help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index 97c6eac675..c8de6c2152 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -24,14 +24,62 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
- def testCreateInferenceGraph(self):
+ def testCreateInferenceGraphWithMfcc(self):
with self.test_session() as sess:
- freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0,
- 40, 'conv')
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='mfcc')
self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
self.assertIsNotNone(
sess.graph.get_tensor_by_name('decoded_sample_data:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(1, ops.count('Mfcc'))
+
+ def testCreateInferenceGraphWithoutMfcc(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=40,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
+
+ def testFeatureBinCount(self):
+ with self.test_session() as sess:
+ freeze.create_inference_graph(
+ wanted_words='a,b,c,d',
+ sample_rate=16000,
+ clip_duration_ms=1000.0,
+ clip_stride_ms=30.0,
+ window_size_ms=30.0,
+ window_stride_ms=10.0,
+ feature_bin_count=80,
+ model_architecture='conv',
+ preprocess='average')
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0'))
+ self.assertIsNotNone(
+ sess.graph.get_tensor_by_name('decoded_sample_data:0'))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0'))
+ ops = [node.op for node in sess.graph_def.node]
+ self.assertEqual(0, ops.count('Mfcc'))
if __name__ == '__main__':
diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
index 053206ae2f..9858906927 100644
--- a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
+++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py
@@ -87,11 +87,12 @@ def main(_):
words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(','))
model_settings = models.prepare_model_settings(
len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms,
- FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
+ FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count,
+ 'mfcc')
audio_processor = input_data.AudioProcessor(
'', FLAGS.data_dir, FLAGS.silence_percentage, 10,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
- FLAGS.testing_percentage, model_settings)
+ FLAGS.testing_percentage, model_settings, FLAGS.data_dir)
output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds
output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32)
@@ -242,10 +243,11 @@ if __name__ == '__main__':
default=10.0,
help='How long the stride is between spectrogram timeslices',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+      help='How many bins to use for the audio fingerprint',
+ )
parser.add_argument(
'--wanted_words',
type=str,
diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py
index 63dd18457f..30f2cfa9fe 100644
--- a/tensorflow/examples/speech_commands/input_data.py
+++ b/tensorflow/examples/speech_commands/input_data.py
@@ -153,14 +153,14 @@ class AudioProcessor(object):
def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
wanted_words, validation_percentage, testing_percentage,
- model_settings):
+ model_settings, summaries_dir):
self.data_dir = data_dir
self.maybe_download_and_extract_dataset(data_url, data_dir)
self.prepare_data_index(silence_percentage, unknown_percentage,
wanted_words, validation_percentage,
testing_percentage)
self.prepare_background_data()
- self.prepare_processing_graph(model_settings)
+ self.prepare_processing_graph(model_settings, summaries_dir)
def maybe_download_and_extract_dataset(self, data_url, dest_directory):
"""Download and extract data set tar file.
@@ -325,7 +325,7 @@ class AudioProcessor(object):
if not self.background_data:
raise Exception('No background wav files were found in ' + search_path)
- def prepare_processing_graph(self, model_settings):
+ def prepare_processing_graph(self, model_settings, summaries_dir):
"""Builds a TensorFlow graph to apply the input distortions.
Creates a graph that loads a WAVE file, decodes it, scales the volume,
@@ -341,48 +341,88 @@ class AudioProcessor(object):
- time_shift_offset_placeholder_: How much to move the clip in time.
- background_data_placeholder_: PCM sample data for background noise.
- background_volume_placeholder_: Loudness of mixed-in background.
- - mfcc_: Output 2D fingerprint of processed audio.
+ - output_: Output 2D fingerprint of processed audio.
Args:
model_settings: Information about the current model being trained.
+ summaries_dir: Path to save training summary information to.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
- desired_samples = model_settings['desired_samples']
- self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
- wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
- wav_decoder = contrib_audio.decode_wav(
- wav_loader, desired_channels=1, desired_samples=desired_samples)
- # Allow the audio sample's volume to be adjusted.
- self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
- scaled_foreground = tf.multiply(wav_decoder.audio,
- self.foreground_volume_placeholder_)
- # Shift the sample's start position, and pad any gaps with zeros.
- self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
- self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
- padded_foreground = tf.pad(
- scaled_foreground,
- self.time_shift_padding_placeholder_,
- mode='CONSTANT')
- sliced_foreground = tf.slice(padded_foreground,
- self.time_shift_offset_placeholder_,
- [desired_samples, -1])
- # Mix in background noise.
- self.background_data_placeholder_ = tf.placeholder(tf.float32,
- [desired_samples, 1])
- self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
- background_mul = tf.multiply(self.background_data_placeholder_,
- self.background_volume_placeholder_)
- background_add = tf.add(background_mul, sliced_foreground)
- background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
- # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
- spectrogram = contrib_audio.audio_spectrogram(
- background_clamp,
- window_size=model_settings['window_size_samples'],
- stride=model_settings['window_stride_samples'],
- magnitude_squared=True)
- self.mfcc_ = contrib_audio.mfcc(
- spectrogram,
- wav_decoder.sample_rate,
- dct_coefficient_count=model_settings['dct_coefficient_count'])
+ with tf.get_default_graph().name_scope('data'):
+ desired_samples = model_settings['desired_samples']
+ self.wav_filename_placeholder_ = tf.placeholder(
+ tf.string, [], name='wav_filename')
+ wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
+ wav_decoder = contrib_audio.decode_wav(
+ wav_loader, desired_channels=1, desired_samples=desired_samples)
+ # Allow the audio sample's volume to be adjusted.
+ self.foreground_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='foreground_volume')
+ scaled_foreground = tf.multiply(wav_decoder.audio,
+ self.foreground_volume_placeholder_)
+ # Shift the sample's start position, and pad any gaps with zeros.
+ self.time_shift_padding_placeholder_ = tf.placeholder(
+ tf.int32, [2, 2], name='time_shift_padding')
+ self.time_shift_offset_placeholder_ = tf.placeholder(
+ tf.int32, [2], name='time_shift_offset')
+ padded_foreground = tf.pad(
+ scaled_foreground,
+ self.time_shift_padding_placeholder_,
+ mode='CONSTANT')
+ sliced_foreground = tf.slice(padded_foreground,
+ self.time_shift_offset_placeholder_,
+ [desired_samples, -1])
+ # Mix in background noise.
+ self.background_data_placeholder_ = tf.placeholder(
+ tf.float32, [desired_samples, 1], name='background_data')
+ self.background_volume_placeholder_ = tf.placeholder(
+ tf.float32, [], name='background_volume')
+ background_mul = tf.multiply(self.background_data_placeholder_,
+ self.background_volume_placeholder_)
+ background_add = tf.add(background_mul, sliced_foreground)
+ background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
+ # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
+ spectrogram = contrib_audio.audio_spectrogram(
+ background_clamp,
+ window_size=model_settings['window_size_samples'],
+ stride=model_settings['window_stride_samples'],
+ magnitude_squared=True)
+ tf.summary.image(
+ 'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
+ # The number of buckets in each FFT row in the spectrogram will depend on
+ # how many input samples there are in each window. This can be quite
+ # large, with a 160 sample window producing 127 buckets for example. We
+ # don't need this level of detail for classification, so we often want to
+ # shrink them down to produce a smaller result. That's what this section
+ # implements. One method is to use average pooling to merge adjacent
+ # buckets, but a more sophisticated approach is to apply the MFCC
+ # algorithm to shrink the representation.
+ if model_settings['preprocess'] == 'average':
+ self.output_ = tf.nn.pool(
+ tf.expand_dims(spectrogram, -1),
+ window_shape=[1, model_settings['average_window_width']],
+ strides=[1, model_settings['average_window_width']],
+ pooling_type='AVG',
+ padding='SAME')
+ tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1)
+ elif model_settings['preprocess'] == 'mfcc':
+ self.output_ = contrib_audio.mfcc(
+ spectrogram,
+ wav_decoder.sample_rate,
+ dct_coefficient_count=model_settings['fingerprint_width'])
+ tf.summary.image(
+ 'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (model_settings['preprocess']))
+
+ # Merge all the summaries and write them out to /tmp/retrain_logs (by
+ # default)
+ self.merged_summaries_ = tf.summary.merge_all(scope='data')
+ self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data',
+ tf.get_default_graph())
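To make the 'average' path concrete, here is a rough numpy equivalent of what
the pooling does to a single spectrogram row (a sketch with assumed sizes, not
part of this change):

```python
import numpy as np

fft_bin_count = 257        # a 480-sample window -> 512-point FFT -> 257 bins
average_window_width = 6   # floor(257 / 40) when feature_bin_count is 40

row = np.random.rand(fft_bin_count)
# Average adjacent buckets; the final partial window is averaged on its own,
# which is roughly what 'SAME' padding does for average pooling.
shrunk = [row[i:i + average_window_width].mean()
          for i in range(0, fft_bin_count, average_window_width)]
print(len(shrunk))  # 43 == fingerprint_width
```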
def set_size(self, mode):
"""Calculates the number of samples in the dataset partition.
@@ -418,6 +458,9 @@ class AudioProcessor(object):
Returns:
List of sample data for the transformed samples, and list of label indexes
+
+ Raises:
+ ValueError: If background samples are too short.
"""
# Pick one of the partitions to choose samples from.
candidates = self.data_index[mode]
@@ -460,6 +503,11 @@ class AudioProcessor(object):
if use_background or sample['label'] == SILENCE_LABEL:
background_index = np.random.randint(len(self.background_data))
background_samples = self.background_data[background_index]
+ if len(background_samples) <= model_settings['desired_samples']:
+ raise ValueError(
+ 'Background sample is too short! Need more than %d'
+ ' samples but only %d were found' %
+ (model_settings['desired_samples'], len(background_samples)))
background_offset = np.random.randint(
0, len(background_samples) - model_settings['desired_samples'])
background_clipped = background_samples[background_offset:(
@@ -482,7 +530,10 @@ class AudioProcessor(object):
else:
input_dict[self.foreground_volume_placeholder_] = 1
# Run the graph to produce the output audio.
- data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
+ summary, data_tensor = sess.run(
+ [self.merged_summaries_, self.output_], feed_dict=input_dict)
+ self.summary_writer_.add_summary(summary)
+ data[i - offset, :] = data_tensor.flatten()
label_index = self.word_to_index[sample['label']]
labels[i - offset] = label_index
return data, labels
diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py
index 13f294d39d..2e551be9a2 100644
--- a/tensorflow/examples/speech_commands/input_data_test.py
+++ b/tensorflow/examples/speech_commands/input_data_test.py
@@ -25,6 +25,7 @@ import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.examples.speech_commands import input_data
+from tensorflow.examples.speech_commands import models
from tensorflow.python.platform import test
@@ -32,7 +33,7 @@ class InputDataTest(test.TestCase):
def _getWavData(self):
with self.test_session() as sess:
- sample_data = tf.zeros([1000, 2])
+ sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
return wav_data
@@ -57,9 +58,31 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
+ def _runGetDataTest(self, preprocess, window_length_ms):
+ tmp_dir = self.get_temp_dir()
+ wav_dir = os.path.join(tmp_dir, "wavs")
+ os.mkdir(wav_dir)
+ self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
+ background_dir = os.path.join(wav_dir, "_background_noise_")
+ os.mkdir(background_dir)
+ wav_data = self._getWavData()
+ for i in range(10):
+ file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
+ self._saveTestWavFile(file_path, wav_data)
+ model_settings = models.prepare_model_settings(
+ 4, 16000, 1000, window_length_ms, 20, 40, preprocess)
+ with self.test_session() as sess:
+ audio_processor = input_data.AudioProcessor(
+ "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
+ result_data, result_labels = audio_processor.get_data(
+ 10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
+ self.assertEqual(10, len(result_data))
+ self.assertEqual(10, len(result_labels))
+
def testPrepareWordsList(self):
words_list = ["a", "b"]
self.assertGreater(
@@ -76,8 +99,9 @@ class InputDataTest(test.TestCase):
def testPrepareDataIndex(self):
tmp_dir = self.get_temp_dir()
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
- audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
- 10, 10, self._model_settings())
+ audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
+ ["a", "b"], 10, 10,
+ self._model_settings(), tmp_dir)
self.assertLess(0, audio_processor.set_size("training"))
self.assertTrue("training" in audio_processor.data_index)
self.assertTrue("validation" in audio_processor.data_index)
@@ -90,7 +114,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10,
- self._model_settings())
+ self._model_settings(), tmp_dir)
self.assertTrue("No .wavs found" in str(e.exception))
def testPrepareDataIndexMissing(self):
@@ -98,7 +122,7 @@ class InputDataTest(test.TestCase):
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
with self.assertRaises(Exception) as e:
_ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10,
- 10, self._model_settings())
+ 10, self._model_settings(), tmp_dir)
self.assertTrue("Expected to find" in str(e.exception))
def testPrepareBackgroundData(self):
@@ -110,8 +134,9 @@ class InputDataTest(test.TestCase):
file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
self._saveTestWavFile(file_path, wav_data)
self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100)
- audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"],
- 10, 10, self._model_settings())
+ audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10,
+ ["a", "b"], 10, 10,
+ self._model_settings(), tmp_dir)
self.assertEqual(10, len(audio_processor.background_data))
def testLoadWavFile(self):
@@ -148,44 +173,27 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
+ 10, 10, model_settings, tmp_dir)
self.assertIsNotNone(audio_processor.wav_filename_placeholder_)
self.assertIsNotNone(audio_processor.foreground_volume_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_)
self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_)
self.assertIsNotNone(audio_processor.background_data_placeholder_)
self.assertIsNotNone(audio_processor.background_volume_placeholder_)
- self.assertIsNotNone(audio_processor.mfcc_)
+ self.assertIsNotNone(audio_processor.output_)
- def testGetData(self):
- tmp_dir = self.get_temp_dir()
- wav_dir = os.path.join(tmp_dir, "wavs")
- os.mkdir(wav_dir)
- self._saveWavFolders(wav_dir, ["a", "b", "c"], 100)
- background_dir = os.path.join(wav_dir, "_background_noise_")
- os.mkdir(background_dir)
- wav_data = self._getWavData()
- for i in range(10):
- file_path = os.path.join(background_dir, "background_audio_%d.wav" % i)
- self._saveTestWavFile(file_path, wav_data)
- model_settings = {
- "desired_samples": 160,
- "fingerprint_size": 40,
- "label_count": 4,
- "window_size_samples": 100,
- "window_stride_samples": 100,
- "dct_coefficient_count": 40,
- }
- audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
- with self.test_session() as sess:
- result_data, result_labels = audio_processor.get_data(
- 10, 0, model_settings, 0.3, 0.1, 100, "training", sess)
- self.assertEqual(10, len(result_data))
- self.assertEqual(10, len(result_labels))
+ def testGetDataAverage(self):
+ self._runGetDataTest("average", 10)
+
+ def testGetDataAverageLongWindow(self):
+ self._runGetDataTest("average", 30)
+
+ def testGetDataMfcc(self):
+ self._runGetDataTest("mfcc", 30)
def testGetUnprocessedData(self):
tmp_dir = self.get_temp_dir()
@@ -198,10 +206,11 @@ class InputDataTest(test.TestCase):
"label_count": 4,
"window_size_samples": 100,
"window_stride_samples": 100,
- "dct_coefficient_count": 40,
+ "fingerprint_width": 40,
+ "preprocess": "mfcc",
}
audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"],
- 10, 10, model_settings)
+ 10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_unprocessed_data(
10, model_settings, "training")
self.assertEqual(10, len(result_data))
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index ab611f414a..65ae3b1511 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -24,9 +24,21 @@ import math
import tensorflow as tf
+def _next_power_of_two(x):
+ """Calculates the smallest enclosing power of two for an input.
+
+ Args:
+ x: Positive float or integer number.
+
+ Returns:
+ Next largest power of two integer.
+ """
+ return 1 if x == 0 else 2**(int(x) - 1).bit_length()
+
+
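As a quick illustration of the helper (values chosen to match the 30 ms
windows at 16 kHz used elsewhere in this example):

```python
def _next_power_of_two(x):
  # Copy of the helper above, repeated so this snippet is self-contained.
  return 1 if x == 0 else 2**(int(x) - 1).bit_length()

print(_next_power_of_two(480))  # 512: a 30 ms window at 16 kHz
print(_next_power_of_two(512))  # 512: powers of two map to themselves
print(_next_power_of_two(513))  # 1024
```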
def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
- window_size_ms, window_stride_ms,
- dct_coefficient_count):
+ window_size_ms, window_stride_ms, feature_bin_count,
+ preprocess):
"""Calculates common settings needed for all models.
Args:
@@ -35,10 +47,14 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
clip_duration_ms: Length of each audio clip to be analyzed.
window_size_ms: Duration of frequency analysis window.
window_stride_ms: How far to move in time between frequency windows.
- dct_coefficient_count: Number of frequency bins to use for analysis.
+ feature_bin_count: Number of frequency bins to use for analysis.
+ preprocess: How the spectrogram is processed to produce features.
Returns:
Dictionary containing common settings.
+
+ Raises:
+ ValueError: If the preprocessing mode isn't recognized.
"""
desired_samples = int(sample_rate * clip_duration_ms / 1000)
window_size_samples = int(sample_rate * window_size_ms / 1000)
@@ -48,16 +64,28 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
spectrogram_length = 0
else:
spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
- fingerprint_size = dct_coefficient_count * spectrogram_length
+ if preprocess == 'average':
+ fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2)
+ average_window_width = int(math.floor(fft_bin_count / feature_bin_count))
+ fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))
+ elif preprocess == 'mfcc':
+ average_window_width = -1
+ fingerprint_width = feature_bin_count
+ else:
+ raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (preprocess))
+ fingerprint_size = fingerprint_width * spectrogram_length
return {
'desired_samples': desired_samples,
'window_size_samples': window_size_samples,
'window_stride_samples': window_stride_samples,
'spectrogram_length': spectrogram_length,
- 'dct_coefficient_count': dct_coefficient_count,
+ 'fingerprint_width': fingerprint_width,
'fingerprint_size': fingerprint_size,
'label_count': label_count,
'sample_rate': sample_rate,
+ 'preprocess': preprocess,
+ 'average_window_width': average_window_width,
}
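Tracing the 'average' branch with the example's usual settings
(sample_rate=16000, window_size_ms=30, feature_bin_count=40) gives a feel for
the numbers; this is an illustrative walkthrough, not additional code in the
change:

```python
import math

def next_power_of_two(x):
  # Same computation as the _next_power_of_two helper above.
  return 1 if x == 0 else 2**(int(x) - 1).bit_length()

window_size_samples = int(16000 * 30 / 1000)                       # 480
fft_bin_count = 1 + (next_power_of_two(window_size_samples) / 2)   # 257.0
average_window_width = int(math.floor(fft_bin_count / 40))         # 6
fingerprint_width = int(math.ceil(fft_bin_count / average_window_width))  # 43
print(window_size_samples, fft_bin_count, average_window_width,
      fingerprint_width)
```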
@@ -106,10 +134,14 @@ def create_model(fingerprint_input, model_settings, model_architecture,
elif model_architecture == 'low_latency_svdf':
return create_low_latency_svdf_model(fingerprint_input, model_settings,
is_training, runtime_settings)
+ elif model_architecture == 'tiny_conv':
+ return create_tiny_conv_model(fingerprint_input, model_settings,
+ is_training)
else:
raise Exception('model_architecture argument "' + model_architecture +
'" not recognized, should be one of "single_fc", "conv",' +
- ' "low_latency_conv, or "low_latency_svdf"')
+                    ' "low_latency_conv", "low_latency_svdf",' +
+ ' or "tiny_conv"')
def load_variables_from_checkpoint(sess, start_checkpoint):
@@ -152,9 +184,12 @@ def create_single_fc_model(fingerprint_input, model_settings, is_training):
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
- weights = tf.Variable(
- tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
- bias = tf.Variable(tf.zeros([label_count]))
+ weights = tf.get_variable(
+ name='weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.001),
+ shape=[fingerprint_size, label_count])
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[label_count])
logits = tf.matmul(fingerprint_input, weights) + bias
if is_training:
return logits, dropout_prob
@@ -212,18 +247,21 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
first_filter_width = 8
first_filter_height = 20
first_filter_count = 64
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
'SAME') + first_bias
first_relu = tf.nn.relu(first_conv)
@@ -235,14 +273,17 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
second_filter_width = 4
second_filter_height = 10
second_filter_count = 64
- second_weights = tf.Variable(
- tf.truncated_normal(
- [
- second_filter_height, second_filter_width, first_filter_count,
- second_filter_count
- ],
- stddev=0.01))
- second_bias = tf.Variable(tf.zeros([second_filter_count]))
+ second_weights = tf.get_variable(
+ name='second_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[
+ second_filter_height, second_filter_width, first_filter_count,
+ second_filter_count
+ ])
+ second_bias = tf.get_variable(
+ name='second_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_filter_count])
second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
'SAME') + second_bias
second_relu = tf.nn.relu(second_conv)
@@ -259,10 +300,14 @@ def create_conv_model(fingerprint_input, model_settings, is_training):
flattened_second_conv = tf.reshape(second_dropout,
[-1, second_conv_element_count])
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_conv_element_count, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_conv_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -318,7 +363,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
"""
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
fingerprint_4d = tf.reshape(fingerprint_input,
[-1, input_time_size, input_frequency_size, 1])
@@ -327,11 +372,14 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
first_filter_count = 186
first_filter_stride_x = 1
first_filter_stride_y = 1
- first_weights = tf.Variable(
- tf.truncated_normal(
- [first_filter_height, first_filter_width, 1, first_filter_count],
- stddev=0.01))
- first_bias = tf.Variable(tf.zeros([first_filter_count]))
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
1, first_filter_stride_y, first_filter_stride_x, 1
], 'VALID') + first_bias
@@ -351,30 +399,42 @@ def create_low_latency_conv_model(fingerprint_input, model_settings,
flattened_first_conv = tf.reshape(first_dropout,
[-1, first_conv_element_count])
first_fc_output_channels = 128
- first_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_conv_element_count, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_conv_element_count, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 128
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
@@ -422,7 +482,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
Args:
fingerprint_input: TensorFlow node that will output audio feature vectors.
The node is expected to produce a 2D Tensor of shape:
- [batch, model_settings['dct_coefficient_count'] *
+ [batch, model_settings['fingerprint_width'] *
model_settings['spectrogram_length']]
with the features corresponding to the same time slot arranged contiguously,
and the oldest slot at index [:, 0], and newest at [:, -1].
@@ -440,7 +500,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
if is_training:
dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
- input_frequency_size = model_settings['dct_coefficient_count']
+ input_frequency_size = model_settings['fingerprint_width']
input_time_size = model_settings['spectrogram_length']
# Validation.
@@ -462,8 +522,11 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
num_filters = rank * num_units
# Create the runtime memory: [num_filters, batch, input_time_size]
batch = 1
- memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
- trainable=False, name='runtime-memory')
+ memory = tf.get_variable(
+ initializer=tf.zeros_initializer,
+ shape=[num_filters, batch, input_time_size],
+ trainable=False,
+ name='runtime-memory')
# Determine the number of new frames in the input, such that we only operate
# on those. For training we do not use the memory, and thus use all frames
# provided in the input.
@@ -483,8 +546,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)
# Create the frequency filters.
- weights_frequency = tf.Variable(
- tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
+ weights_frequency = tf.get_variable(
+ name='weights_frequency',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[input_frequency_size, num_filters])
# Expand to add input channels dimensions.
# weights_frequency: [input_frequency_size, 1, num_filters]
weights_frequency = tf.expand_dims(weights_frequency, 1)
@@ -506,8 +571,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
activations_time = new_memory
# Create the time filters.
- weights_time = tf.Variable(
- tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
+ weights_time = tf.get_variable(
+ name='weights_time',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_filters, input_time_size])
# Apply the time filter on the outputs of the feature filters.
# weights_time: [num_filters, input_time_size, 1]
# outputs: [num_filters, batch, 1]
@@ -524,7 +591,8 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
units_output = tf.transpose(units_output)
  # Apply bias.
- bias = tf.Variable(tf.zeros([num_units]))
+ bias = tf.get_variable(
+ name='bias', initializer=tf.zeros_initializer, shape=[num_units])
first_bias = tf.nn.bias_add(units_output, bias)
# Relu.
@@ -536,31 +604,135 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
first_dropout = first_relu
first_fc_output_channels = 256
- first_fc_weights = tf.Variable(
- tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
- first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
+ first_fc_weights = tf.get_variable(
+ name='first_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[num_units, first_fc_output_channels])
+ first_fc_bias = tf.get_variable(
+ name='first_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_fc_output_channels])
first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
if is_training:
second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
else:
second_fc_input = first_fc
second_fc_output_channels = 256
- second_fc_weights = tf.Variable(
- tf.truncated_normal(
- [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
- second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
+ second_fc_weights = tf.get_variable(
+ name='second_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_fc_output_channels, second_fc_output_channels])
+ second_fc_bias = tf.get_variable(
+ name='second_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[second_fc_output_channels])
second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
if is_training:
final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
else:
final_fc_input = second_fc
label_count = model_settings['label_count']
- final_fc_weights = tf.Variable(
- tf.truncated_normal(
- [second_fc_output_channels, label_count], stddev=0.01))
- final_fc_bias = tf.Variable(tf.zeros([label_count]))
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+      initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[second_fc_output_channels, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
if is_training:
return final_fc, dropout_prob
else:
return final_fc
+
+
+def create_tiny_conv_model(fingerprint_input, model_settings, is_training):
+ """Builds a convolutional model aimed at microcontrollers.
+
+ Devices like DSPs and microcontrollers can have very small amounts of
+ memory and limited processing power. This model is designed to use less
+ than 20KB of working RAM, and fit within 32KB of read-only (flash) memory.
+
+ Here's the layout of the graph:
+
+ (fingerprint_input)
+ v
+ [Conv2D]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+ [Relu]
+ v
+ [MatMul]<-(weights)
+ v
+ [BiasAdd]<-(bias)
+ v
+
+  This doesn't produce particularly accurate results, but it's designed to be
+  used as the first stage of a pipeline, running on a low-energy piece of
+  hardware that can always be on, waking higher-power chips only when a
+  possible utterance has been found, so that more accurate analysis can be done.
+
+ During training, a dropout node is introduced after the relu, controlled by a
+ placeholder.
+
+ Args:
+ fingerprint_input: TensorFlow node that will output audio feature vectors.
+ model_settings: Dictionary of information about the model.
+ is_training: Whether the model is going to be used for training.
+
+ Returns:
+ TensorFlow node outputting logits results, and optionally a dropout
+ placeholder.
+ """
+ if is_training:
+ dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
+ input_frequency_size = model_settings['fingerprint_width']
+ input_time_size = model_settings['spectrogram_length']
+ fingerprint_4d = tf.reshape(fingerprint_input,
+ [-1, input_time_size, input_frequency_size, 1])
+ first_filter_width = 8
+ first_filter_height = 10
+ first_filter_count = 8
+ first_weights = tf.get_variable(
+ name='first_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_filter_height, first_filter_width, 1, first_filter_count])
+ first_bias = tf.get_variable(
+ name='first_bias',
+ initializer=tf.zeros_initializer,
+ shape=[first_filter_count])
+ first_conv_stride_x = 2
+ first_conv_stride_y = 2
+ first_conv = tf.nn.conv2d(fingerprint_4d, first_weights,
+ [1, first_conv_stride_y, first_conv_stride_x, 1],
+ 'SAME') + first_bias
+ first_relu = tf.nn.relu(first_conv)
+ if is_training:
+ first_dropout = tf.nn.dropout(first_relu, dropout_prob)
+ else:
+ first_dropout = first_relu
+ first_dropout_shape = first_dropout.get_shape()
+ first_dropout_output_width = first_dropout_shape[2]
+ first_dropout_output_height = first_dropout_shape[1]
+ first_dropout_element_count = int(
+ first_dropout_output_width * first_dropout_output_height *
+ first_filter_count)
+ flattened_first_dropout = tf.reshape(first_dropout,
+ [-1, first_dropout_element_count])
+ label_count = model_settings['label_count']
+ final_fc_weights = tf.get_variable(
+ name='final_fc_weights',
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
+ shape=[first_dropout_element_count, label_count])
+ final_fc_bias = tf.get_variable(
+ name='final_fc_bias',
+ initializer=tf.zeros_initializer,
+ shape=[label_count])
+ final_fc = (
+ tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias)
+ if is_training:
+ return final_fc, dropout_prob
+ else:
+ return final_fc
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 80c795367f..0c373967ed 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -26,12 +26,29 @@ from tensorflow.python.platform import test
class ModelsTest(test.TestCase):
+ def _modelSettings(self):
+ return models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc")
+
def testPrepareModelSettings(self):
self.assertIsNotNone(
- models.prepare_model_settings(10, 16000, 1000, 20, 10, 40))
+ models.prepare_model_settings(
+ label_count=10,
+ sample_rate=16000,
+ clip_duration_ms=1000,
+ window_size_ms=20,
+ window_stride_ms=10,
+ feature_bin_count=40,
+ preprocess="mfcc"))
def testCreateModelConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
@@ -42,7 +59,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelConvInference(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
@@ -51,7 +68,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
def testCreateModelLowLatencyConvTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -62,7 +79,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelFullyConnectedTraining(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
@@ -73,7 +90,7 @@ class ModelsTest(test.TestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
def testCreateModelBadArchitecture(self):
- model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)
+ model_settings = self._modelSettings()
with self.test_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
@@ -81,6 +98,17 @@ class ModelsTest(test.TestCase):
"bad_architecture", True)
self.assertTrue("not recognized" in str(e.exception))
+ def testCreateModelTinyConvTraining(self):
+ model_settings = self._modelSettings()
+ with self.test_session() as sess:
+ fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
+ logits, dropout_prob = models.create_model(
+ fingerprint_input, model_settings, "tiny_conv", True)
+ self.assertIsNotNone(logits)
+ self.assertIsNotNone(dropout_prob)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
+ self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index fc28eb0631..eca34f8812 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -98,12 +98,12 @@ def main(_):
model_settings = models.prepare_model_settings(
len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
- FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
+ FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
audio_processor = input_data.AudioProcessor(
- FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
- FLAGS.unknown_percentage,
+ FLAGS.data_url, FLAGS.data_dir,
+ FLAGS.silence_percentage, FLAGS.unknown_percentage,
FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
- FLAGS.testing_percentage, model_settings)
+ FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
fingerprint_size = model_settings['fingerprint_size']
label_count = model_settings['label_count']
time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
@@ -122,8 +122,25 @@ def main(_):
'lists, but are %d and %d long instead' % (len(training_steps_list),
len(learning_rates_list)))
- fingerprint_input = tf.placeholder(
+ input_placeholder = tf.placeholder(
tf.float32, [None, fingerprint_size], name='fingerprint_input')
+ if FLAGS.quantize:
+ # TODO(petewarden): These values have been derived from the observed ranges
+ # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
+ # they may need to be updated.
+ if FLAGS.preprocess == 'average':
+ fingerprint_min = 0.0
+ fingerprint_max = 2048.0
+ elif FLAGS.preprocess == 'mfcc':
+ fingerprint_min = -247.0
+ fingerprint_max = 30.0
+ else:
+ raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or'
+ ' "average")' % (FLAGS.preprocess))
+ fingerprint_input = tf.fake_quant_with_min_max_args(
+ input_placeholder, fingerprint_min, fingerprint_max)
+ else:
+ fingerprint_input = input_placeholder
logits, dropout_prob = models.create_model(
fingerprint_input,
@@ -146,7 +163,8 @@ def main(_):
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
- tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ if FLAGS.quantize:
+ tf.contrib.quantize.create_training_graph(quant_delay=0)
with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
learning_rate_input = tf.placeholder(
tf.float32, [], name='learning_rate_input')
@@ -157,7 +175,9 @@ def main(_):
confusion_matrix = tf.confusion_matrix(
ground_truth_input, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- tf.summary.scalar('accuracy', evaluation_step)
+ with tf.get_default_graph().name_scope('eval'):
+ tf.summary.scalar('cross_entropy', cross_entropy_mean)
+ tf.summary.scalar('accuracy', evaluation_step)
global_step = tf.train.get_or_create_global_step()
increment_global_step = tf.assign(global_step, global_step + 1)
@@ -165,7 +185,7 @@ def main(_):
saver = tf.train.Saver(tf.global_variables())
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
- merged_summaries = tf.summary.merge_all()
+ merged_summaries = tf.summary.merge_all(scope='eval')
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
sess.graph)
validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
@@ -207,8 +227,11 @@ def main(_):
# Run the graph with this batch of training data.
train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
[
- merged_summaries, evaluation_step, cross_entropy_mean, train_step,
- increment_global_step
+ merged_summaries,
+ evaluation_step,
+ cross_entropy_mean,
+ train_step,
+ increment_global_step,
],
feed_dict={
fingerprint_input: train_fingerprints,
@@ -364,10 +387,11 @@ if __name__ == '__main__':
default=10.0,
help='How far to move in time between spectrogram timeslices.',)
parser.add_argument(
- '--dct_coefficient_count',
+ '--feature_bin_count',
type=int,
default=40,
- help='How many bins to use for the MFCC fingerprint',)
+ help='How many bins to use for the MFCC fingerprint',
+ )
parser.add_argument(
'--how_many_training_steps',
type=str,
@@ -423,6 +447,16 @@ if __name__ == '__main__':
type=bool,
default=False,
help='Whether to check for invalid numbers during processing')
+ parser.add_argument(
+ '--quantize',
+ type=bool,
+ default=False,
+ help='Whether to train the model for eight-bit deployment')
+ parser.add_argument(
+ '--preprocess',
+ type=str,
+ default='mfcc',
+ help='Spectrogram processing mode. Can be "mfcc" or "average"')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
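+
+# Example invocation (hypothetical flag values) exercising the new flags:
+#
+#   python train.py --quantize=True --preprocess=average --feature_bin_count=40
+#
+# Note: --quantize is declared with type=bool, so argparse treats any
+# non-empty value (including 'False') as True; omit the flag or pass an
+# empty string to keep it False.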
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 08943a527c..32a77550ee 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -177,7 +177,14 @@ type OpSpec struct {
// being added.
ControlDependencies []*Operation
- // Other possible fields: Device, ColocateWith.
+ // The device on which the operation should be executed.
+ // If omitted, an appropriate device will automatically be selected.
+ //
+ // For example, if set to "/device:GPU:0", then the operation will
+ // execute on GPU #0.
+ Device string
+
+ // Other possible fields: ColocateWith.
}
// AddOperation adds an operation to g.
@@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
return nil, fmt.Errorf("%v (memory will be leaked)", err)
}
}
+ if len(args.Device) > 0 {
+ cdevice := C.CString(args.Device)
+ C.TF_SetDevice(cdesc, cdevice)
+ C.free(unsafe.Pointer(cdevice))
+ }
c := C.TF_FinishOperation(cdesc, status.c)
if err := status.Err(); err != nil {
return nil, err
diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go
index 13de4294dc..ac39808d83 100644
--- a/tensorflow/go/op/scope.go
+++ b/tensorflow/go/op/scope.go
@@ -37,6 +37,7 @@ type Scope struct {
namemap map[string]int
namespace string
controlDependencies []*tf.Operation
+ device string
err *scopeErr
}
@@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
args.Name = s.namespace + "/" + args.Name
}
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
+ args.Device = s.device
op, err := s.graph.AddOperation(args)
if err != nil {
s.UpdateErr(args.Type, err)
@@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
namespace = s.namespace + "/" + namespace
}
return &Scope{
- graph: s.graph,
- namemap: make(map[string]int),
- namespace: namespace,
- err: s.err,
+ graph: s.graph,
+ namemap: make(map[string]int),
+ namespace: namespace,
+ controlDependencies: s.controlDependencies,
+ device: s.device,
+ err: s.err,
}
}
@@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: deps,
+ device: s.device,
+ err: s.err,
+ }
+}
+
+// WithDevice returns a new Scope which will cause all operations added to the
+// graph to execute on devices that match the provided device specification.
+//
+// For example, WithDevice("/device:GPU:0") will cause operations added to
+// the graph to execute on GPU #0.
+//
+// An empty string removes any device restrictions.
+func (s *Scope) WithDevice(device string) *Scope {
+ return &Scope{
+ graph: s.graph,
+ namemap: s.namemap,
+ namespace: s.namespace,
+ controlDependencies: s.controlDependencies,
+ device: device,
err: s.err,
}
}
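+
+// A minimal usage sketch (mirrors the TestDevice case below): pin one
+// matmul to GPU #0 while leaving other operations unconstrained.
+//
+//	s := NewScope()
+//	m := Const(s, [][]float32{{3.0}})
+//	sq := MatMul(s.WithDevice("/device:GPU:0").SubScope("square"), m, m)
+//	// sq.Op.Device() == "/device:GPU:0"; ops added via s remain unpinned.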
diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go
index b58a61de98..be7b0ad892 100644
--- a/tensorflow/go/op/scope_test.go
+++ b/tensorflow/go/op/scope_test.go
@@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
}
}
+func TestDevice(t *testing.T) {
+ s := NewScope()
+ matrix := Const(s, [][]float32{{3.0}})
+ s = s.WithDevice("/device:GPU:0")
+ square := MatMul(s.SubScope("square"), matrix, matrix)
+ s = s.WithDevice("")
+ cube := MatMul(s.SubScope("cube"), square, matrix)
+ if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+ if got, want := cube.Op.Device(), ""; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func TestScopeFinalize(t *testing.T) {
var (
root = NewScope()
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index d20e88e95b..f49e1cecaf 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3069,6 +3069,152 @@ func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Outpu
return op.Output(0)
}
+// Updates specified rows with values in `v`.
+//
+// Computes `x[i, :] = v; return x`.
+//
+// Arguments:
+// x: A tensor of type `T`.
+// i: A vector. Indices into the left-most dimension of `x`.
+// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
+//
+// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
+func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InplaceUpdate",
+ Input: []tf.Input{
+ x, i, v,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
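+
+// For example (values assumed): with x = [[1, 2], [3, 4]], i = [0] and
+// v = [[9, 9]], the result is [[9, 9], [3, 4]], and y aliases x.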
+
+// Makes a copy of `x`.
+//
+// Arguments:
+// x: The source tensor of type `T`.
+//
+// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
+// is not an alias of `x`.
+func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DeepCopy",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// PackAttr is an optional argument to Pack.
+type PackAttr func(optionalAttr)
+
+// PackAxis sets the optional axis attribute to value.
+//
+// value: Dimension along which to pack. Negative values wrap around, so the
+// valid range is `[-(R+1), R+1)`.
+// If not specified, defaults to 0
+func PackAxis(value int64) PackAttr {
+ return func(m optionalAttr) {
+ m["axis"] = value
+ }
+}
+
+// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
+//
+// Packs the `N` tensors in `values` into a tensor with rank one higher than each
+// tensor in `values`, by packing them along the `axis` dimension.
+// Given a list of tensors of shape `(A, B, C)`;
+//
+// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
+// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
+// Etc.
+//
+// For example:
+//
+// ```
+// # 'x' is [1, 4]
+// # 'y' is [2, 5]
+// # 'z' is [3, 6]
+// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
+// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
+// ```
+//
+// This is the opposite of `unpack`.
+//
+// Arguments:
+// values: Must be of same shape and type.
+//
+// Returns The packed tensor.
+func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Pack",
+ Input: []tf.Input{
+ tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Concatenates a list of `N` tensors along the first dimension.
+//
+// The input tensors are all required to have size 1 in the first dimension.
+//
+// For example:
+//
+// ```
+// # 'x' is [[1, 4]]
+// # 'y' is [[2, 5]]
+// # 'z' is [[3, 6]]
+// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
+// ```
+//
+// The difference between concat and parallel_concat is that concat requires
+// all of the inputs to be computed before the operation will begin, but
+// doesn't require that the input shapes be known during graph construction.
+// Parallel concat will copy pieces of the input into the output as they
+// become available; in some situations this can provide a performance benefit.
+//
+// Arguments:
+// values: Tensors to be concatenated. All must have size 1 in the first dimension
+// and same shape.
+// shape: the final shape of the result; should be equal to the shapes of any input
+// but with the number of input values in the first dimension.
+//
+// Returns The concatenated tensor.
+func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "ParallelConcat",
+ Input: []tf.Input{
+ tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
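+
+// A hedged sketch (x, y and z are assumed 1x2 constants, e.g. from Const;
+// tf.MakeShape is assumed available for building the result shape):
+//
+//	out := ParallelConcat(s, []tf.Output{x, y, z}, tf.MakeShape(3, 2))
+//	// out has shape [3, 2], matching the parallel_concat example above.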
+
// Computes the mean along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -3121,6 +3267,57 @@ func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.
return op.Output(0)
}
+// Computes the sum along sparse segments of a tensor.
+//
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// # [0 0 0 0]
+// # [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+// tf.constant([0, 1]),
+// tf.constant([0, 2]),
+// num_segments=4)
+// # => [[ 1 2 3 4]
+// # [ 0 0 0 0]
+// # [-1 -2 -3 -4]
+// # [ 0 0 0 0]]
+// ```
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSumWithNumSegments",
+ Input: []tf.Input{
+ data, indices, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// PreventGradientAttr is an optional argument to PreventGradient.
type PreventGradientAttr func(optionalAttr)
@@ -6071,53 +6268,6 @@ func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) {
return op.Output(0)
}
-// AvgPool3DAttr is an optional argument to AvgPool3D.
-type AvgPool3DAttr func(optionalAttr)
-
-// AvgPool3DDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format of the input and output data. With the
-// default format "NDHWC", the data is stored in the order of:
-// [batch, in_depth, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCDHW", the data storage order is:
-// [batch, in_channels, in_depth, in_height, in_width].
-// If not specified, defaults to "NDHWC"
-func AvgPool3DDataFormat(value string) AvgPool3DAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Performs 3D average pooling on the input.
-//
-// Arguments:
-// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
-// ksize: 1-D tensor of length 5. The size of the window for each dimension of
-// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
-// strides: 1-D tensor of length 5. The stride of the sliding window for each
-// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
-// padding: The type of padding algorithm to use.
-//
-// Returns The average pooled output tensor.
-func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AvgPool3D",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns element-wise remainder of division. This emulates C semantics in that
//
// the result here is consistent with a truncating divide. E.g.
@@ -8114,27 +8264,6 @@ func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_
return op.Output(0)
}
-// Makes a copy of `x`.
-//
-// Arguments:
-// x: The source tensor of type `T`.
-//
-// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
-// is not an alias of `x`.
-func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DeepCopy",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -13876,6 +14005,83 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// MfccAttr is an optional argument to Mfcc.
+type MfccAttr func(optionalAttr)
+
+// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
+//
+// value: The highest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 4000
+func MfccUpperFrequencyLimit(value float32) MfccAttr {
+ return func(m optionalAttr) {
+ m["upper_frequency_limit"] = value
+ }
+}
+
+// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
+//
+// value: The lowest frequency to use when calculating the
+// cepstrum.
+// If not specified, defaults to 20
+func MfccLowerFrequencyLimit(value float32) MfccAttr {
+ return func(m optionalAttr) {
+ m["lower_frequency_limit"] = value
+ }
+}
+
+// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
+//
+// value: Resolution of the Mel bank used internally.
+// If not specified, defaults to 40
+func MfccFilterbankChannelCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["filterbank_channel_count"] = value
+ }
+}
+
+// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
+//
+// value: How many output channels to produce per time slice.
+// If not specified, defaults to 13
+func MfccDctCoefficientCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["dct_coefficient_count"] = value
+ }
+}
+
+// Transforms a spectrogram into a form that's useful for speech recognition.
+//
+// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+// been effective as an input feature for machine learning. They are created by
+// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+// higher frequencies that are less significant to the human ear. They have a long
+// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+// is a good resource to learn more.
+//
+// Arguments:
+// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+// set to true.
+// sample_rate: How many samples per second the source audio used.
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Mfcc",
+ Input: []tf.Input{
+ spectrogram, sample_rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
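+
+// A hedged pipeline sketch (AudioSpectrogram and its attribute setter are
+// assumed to be the generated wrappers for the Spectrogram op mentioned
+// above):
+//
+//	audio, sampleRate := DecodeWav(s, wavContents)
+//	spec := AudioSpectrogram(s, audio, 640, 320,
+//		AudioSpectrogramMagnitudeSquared(true))
+//	mfcc := Mfcc(s, spec, sampleRate, MfccDctCoefficientCount(13))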
+
// AudioSummaryAttr is an optional argument to AudioSummary.
type AudioSummaryAttr func(optionalAttr)
@@ -14294,65 +14500,6 @@ func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths
return op.Output(0)
}
-// PackAttr is an optional argument to Pack.
-type PackAttr func(optionalAttr)
-
-// PackAxis sets the optional axis attribute to value.
-//
-// value: Dimension along which to pack. Negative values wrap around, so the
-// valid range is `[-(R+1), R+1)`.
-// If not specified, defaults to 0
-func PackAxis(value int64) PackAttr {
- return func(m optionalAttr) {
- m["axis"] = value
- }
-}
-
-// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
-//
-// Packs the `N` tensors in `values` into a tensor with rank one higher than each
-// tensor in `values`, by packing them along the `axis` dimension.
-// Given a list of tensors of shape `(A, B, C)`;
-//
-// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
-// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
-// Etc.
-//
-// For example:
-//
-// ```
-// # 'x' is [1, 4]
-// # 'y' is [2, 5]
-// # 'z' is [3, 6]
-// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
-// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
-// ```
-//
-// This is the opposite of `unpack`.
-//
-// Arguments:
-// values: Must be of same shape and type.
-//
-// Returns The packed tensor.
-func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Pack",
- Input: []tf.Input{
- tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Reorders a SparseTensor into the canonical, row-major ordering.
//
// Note that by convention, all sparse ops preserve the canonical ordering along
@@ -15012,30 +15159,6 @@ func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Updates specified rows with values in `v`.
-//
-// Computes `x[i, :] = v; return x`.
-//
-// Arguments:
-// x: A tensor of type `T`.
-// i: A vector. Indices into the left-most dimension of `x`.
-// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
-//
-// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
-func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InplaceUpdate",
- Input: []tf.Input{
- x, i, v,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
type FusedBatchNormAttr func(optionalAttr)
@@ -20445,83 +20568,6 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x
return op.Output(0), op.Output(1), op.Output(2)
}
-// MfccAttr is an optional argument to Mfcc.
-type MfccAttr func(optionalAttr)
-
-// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
-//
-// value: The highest frequency to use when calculating the
-// ceptstrum.
-// If not specified, defaults to 4000
-func MfccUpperFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["upper_frequency_limit"] = value
- }
-}
-
-// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
-//
-// value: The lowest frequency to use when calculating the
-// ceptstrum.
-// If not specified, defaults to 20
-func MfccLowerFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["lower_frequency_limit"] = value
- }
-}
-
-// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
-//
-// value: Resolution of the Mel bank used internally.
-// If not specified, defaults to 40
-func MfccFilterbankChannelCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["filterbank_channel_count"] = value
- }
-}
-
-// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
-//
-// value: How many output channels to produce per time slice.
-// If not specified, defaults to 13
-func MfccDctCoefficientCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["dct_coefficient_count"] = value
- }
-}
-
-// Transforms a spectrogram into a form that's useful for speech recognition.
-//
-// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
-// been effective as an input feature for machine learning. They are created by
-// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
-// higher frequencies that are less significant to the human ear. They have a long
-// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
-// is a good resource to learn more.
-//
-// Arguments:
-// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
-// set to true.
-// sample_rate: How many samples per second the source audio used.
-func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Mfcc",
- Input: []tf.Input{
- spectrogram, sample_rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Given a quantized tensor described by (input, input_min, input_max), outputs a
//
// range that covers the actual values present in that tensor. This op is
@@ -25802,57 +25848,6 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor.
-//
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// For example:
-//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
-//
-// tf.sparse_segment_sum_with_num_segments(
-// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// # [0 0 0 0]
-// # [0 0 0 0]]
-//
-// tf.sparse_segment_sum_with_num_segments(c,
-// tf.constant([0, 1]),
-// tf.constant([0, 2],
-// num_segments=4))
-// # => [[ 1 2 3 4]
-// # [ 0 0 0 0]
-// # [-1 -2 -3 -4]
-// # [ 0 0 0 0]]
-// ```
-//
-// Arguments:
-//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentSumWithNumSegments",
- Input: []tf.Input{
- data, indices, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that executes a SQL query and emits rows of the result set.
//
// Arguments:
@@ -26467,6 +26462,53 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
return op.Output(0)
}
+// AvgPool3DAttr is an optional argument to AvgPool3D.
+type AvgPool3DAttr func(optionalAttr)
+
+// AvgPool3DDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format of the input and output data. With the
+// default format "NDHWC", the data is stored in the order of:
+// [batch, in_depth, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCDHW", the data storage order is:
+// [batch, in_channels, in_depth, in_height, in_width].
+// If not specified, defaults to "NDHWC"
+func AvgPool3DDataFormat(value string) AvgPool3DAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Performs 3D average pooling on the input.
+//
+// Arguments:
+// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
+// ksize: 1-D tensor of length 5. The size of the window for each dimension of
+// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+// strides: 1-D tensor of length 5. The stride of the sliding window for each
+// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+// padding: The type of padding algorithm to use.
+//
+// Returns The average pooled output tensor.
+func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPool3D",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
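+
+// A minimal sketch (input is an assumed 5-D tensor): a 2x2x2 pooling window
+// with matching stride; ksize[0], ksize[4], strides[0] and strides[4] must
+// stay 1 to leave batch and channel dimensions untouched.
+//
+//	out := AvgPool3D(s, input, []int64{1, 2, 2, 2, 1}, []int64{1, 2, 2, 2, 1}, "VALID")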
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -30668,45 +30710,3 @@ func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (aud
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Concatenates a list of `N` tensors along the first dimension.
-//
-// The input tensors are all required to have size 1 in the first dimension.
-//
-// For example:
-//
-// ```
-// # 'x' is [[1, 4]]
-// # 'y' is [[2, 5]]
-// # 'z' is [[3, 6]]
-// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
-// ```
-//
-// The difference between concat and parallel_concat is that concat requires all
-// of the inputs be computed before the operation will begin but doesn't require
-// that the input shapes be known during graph construction. Parallel concat
-// will copy pieces of the input into the output as they become available, in
-// some situations this can provide a performance benefit.
-//
-// Arguments:
-// values: Tensors to be concatenated. All must have size 1 in the first dimension
-// and same shape.
-// shape: the final shape of the result; should be equal to the shapes of any input
-// but with the number of input values in the first dimension.
-//
-// Returns The concatenated tensor.
-func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"shape": shape}
- opspec := tf.OpSpec{
- Type: "ParallelConcat",
- Input: []tf.Input{
- tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 25ec718703..d6a37e0a86 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int {
return int(C.TF_OperationNumOutputs(op.c))
}
+// Device returns a specification of the device on which this operation
+// will be executed, or the empty string if there is no such specification.
+func (op *Operation) Device() string {
+ return C.GoString(C.TF_OperationDevice(op.c))
+}
+
// OutputListSize returns the size of the list of Outputs that is produced by a
// named output of op.
//
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 06b65bdfb7..4af9e33ad0 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
}
}
+func TestOperationDevice(t *testing.T) {
+ graph := NewGraph()
+ v, err := NewTensor(float32(1.0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ op, err := graph.AddOperation(OpSpec{
+ Type: "Const",
+ Name: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": v.DataType(),
+ "value": v,
+ },
+ Device: "/device:GPU:0",
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
index 0642be06fa..7391dfb965 100644
--- a/tensorflow/java/maven/hadoop/pom.xml
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -1,12 +1,30 @@
-<project
- xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
- <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+ <groupId>org.tensorflow</groupId>
<artifactId>hadoop</artifactId>
<packaging>jar</packaging>
+ <version>1.9.0</version>
+ <name>tensorflow-hadoop</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <maven.compiler.source>1.6</maven.compiler.source>
+ <maven.compiler.target>1.6</maven.compiler.target>
+ <hadoop.version>2.6.0</hadoop.version>
+ <protobuf.version>3.3.1</protobuf.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <licenses>
+ <license>
+ <name>Apache License Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ </license>
+ </licenses>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
@@ -14,11 +32,161 @@
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
- <url>https://github.com/tensorflow/ecosystem/</url>
- <parent>
- <groupId>org.tensorflow</groupId>
- <artifactId>parentpom</artifactId>
- <version>1.9.0-rc0</version>
- <relativePath>../</relativePath>
- </parent>
-</project> \ No newline at end of file
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.2.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <version>2.9.1</version>
+ <executions>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-core</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+ <version>${hadoop.version}</version>
+ <type>test-jar</type>
+ <optional>true</optional>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ </dependencies>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profiles>
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+</project>
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index a7fa9ea5cc..d44bdf8f81 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 83aae29f1e..e8925c6fb1 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 50bd8ee5f9..3bf4a2590c 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index b4746794ea..b96dcf2888 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 618a2a124c..5581d864d7 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 2e771064e4..2240d6b7b9 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -203,7 +203,10 @@ download_tf_ecosystem() {
cd "${ECOSYSTEM_DIR}"
git clone "${TF_ECOSYSTEM_URL}"
cd ecosystem
- git checkout r${TF_VERSION}
+ # TF_VERSION is a semver string (<major>.<minor>.<patch>[-suffix])
+ # but the branch is just (r<major>.<minor>).
+ RELEASE_BRANCH=$(echo "${TF_VERSION}" | sed -e 's/\([0-9]\+\.[0-9]\+\)\.[0-9]\+.*/\1/')
+ git checkout r${RELEASE_BRANCH}
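+ # For example, TF_VERSION=1.9.0 yields RELEASE_BRANCH=1.9, so the
+ # checkout above targets the r1.9 branch.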
# Copy the TensorFlow Hadoop source
cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
index 19c752d08b..64956be02c 100644
--- a/tensorflow/java/maven/spark-connector/pom.xml
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -1,12 +1,23 @@
-<project
- xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build -->
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
- <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
- <artifactId>spark-connector</artifactId>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>spark-connector_2.11</artifactId>
<packaging>jar</packaging>
+ <version>1.9.0</version>
+ <name>spark-tensorflow-connector</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
+
+ <licenses>
+ <license>
+ <name>The Apache Software License, Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ <distribution>repo</distribution>
+ </license>
+ </licenses>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
@@ -14,11 +25,325 @@
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
- <url>https://github.com/tensorflow/ecosystem/</url>
- <parent>
- <groupId>org.tensorflow</groupId>
- <artifactId>parentpom</artifactId>
- <version>1.9.0-rc0</version>
- <relativePath>../</relativePath>
- </parent>
-</project> \ No newline at end of file
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <scala.maven.version>3.2.2</scala.maven.version>
+ <scala.binary.version>2.11</scala.binary.version>
+ <scalatest.maven.version>1.0</scalatest.maven.version>
+ <scala.test.version>2.2.6</scala.test.version>
+ <maven.compiler.version>3.0</maven.compiler.version>
+ <java.version>1.8</java.version>
+ <spark.version>2.3.0</spark.version>
+ <yarn.api.version>2.7.3</yarn.api.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>${scala.maven.version}</version>
+ <executions>
+ <execution>
+ <id>compile</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>compile</goal>
+ </goals>
+ <configuration>
+ <jvmArgs>
+ <jvmArg>-Xms256m</jvmArg>
+ <jvmArg>-Xmx512m</jvmArg>
+ </jvmArgs>
+ <args>
+ <arg>-g:vars</arg>
+ <arg>-deprecation</arg>
+ <arg>-feature</arg>
+ <arg>-unchecked</arg>
+ <arg>-Xfatal-warnings</arg>
+ <arg>-language:implicitConversions</arg>
+ <arg>-language:existentials</arg>
+ </args>
+ </configuration>
+ </execution>
+ <execution>
+ <id>test</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>testCompile</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>doc-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <recompileMode>incremental</recompileMode>
+ <useZincServer>true</useZincServer>
+ <scalaVersion>${scala.binary.version}</scalaVersion>
+ <checkMultipleScalaVersions>false</checkMultipleScalaVersions>
+ </configuration>
+ </plugin>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <version>${scalatest.maven.version}</version>
+ <executions>
+ <execution>
+ <id>scalaTest</id>
+ <phase>test</phase>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- Shade protobuf dependency. -->
+ <plugin>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>3.1.0</version>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <minimizeJar>true</minimizeJar>
+ <artifactSet>
+ <includes>
+ <include>com.google.protobuf:protobuf-java</include>
+ <include>org.tensorflow:hadoop</include>
+ <include>org.tensorflow:proto</include>
+ </includes>
+ </artifactSet>
+ <filters>
+ <filter>
+ <!-- Remove the source to keep the result smaller. -->
+ <artifact>com.google.protobuf:protobuf-java</artifact>
+ <excludes>
+ <exclude>**/*.java</exclude>
+ </excludes>
+ </filter>
+ </filters>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>
+ org.tensorflow.spark.shaded.com.google.protobuf
+ </shadedPattern>
+ </relocation>
+ </relocations>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- GPG signed components: http://central.sonatype.org/pages/apache-maven.html#gpg-signed-components -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>${maven.compiler.version}</version>
+ <configuration>
+ <source>${java.version}</source>
+ <target>${java.version}</target>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.2.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <version>2.9.1</version>
+ <executions>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+ <profiles>
+ <profile>
+ <id>test</id>
+ <activation>
+ <activeByDefault>true</activeByDefault>
+ <property>
+ <name>!NEVERSETME</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ <dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <version>${scala.test.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </profile>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>hadoop</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <version>${yarn.api.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 157c4b8e82..92e15aa2c7 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index f5f54bf4d3..d9d6f8adc8 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
-#include <string>
#include <list>
#include <map>
+#include <string>
#include <utility>
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 759d800ecf..05decd6b54 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
index 63e99fbb04..941ab2699c 100644
--- a/tensorflow/java/src/gen/cc/op_specs.cc
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
-#include <vector>
#include <string>
#include <utility>
+#include <vector>
#include "re2/re2.h"
#include "tensorflow/core/framework/op.h"
@@ -50,7 +50,7 @@ class TypeResolver {
// For example, if the argument's datatype is DT_STRING, this method will
// return "java.lang.String", so the argument can become "Operand<String>"
// in the Ops API
- Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
// Returns types of an input attribute
//
@@ -62,7 +62,7 @@ class TypeResolver {
// <java.lang.Float, float>, so the attribute can be used as a "Float" object
// in the Ops API and casted to a "float" when passing through the JNI layer.
std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
- bool *iterable_out);
+ bool* iterable_out);
// Returns true if the type of this attribute has already been resolved
bool IsAttributeVisited(const string& attr_name) {
@@ -89,8 +89,7 @@ class TypeResolver {
}
};
-Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
- bool* iterable_out) {
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = false;
if (!arg_def.number_attr().empty()) {
// when number_attr is set, argument has to be a list of tensors
@@ -154,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
} else {
LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
return type;
}
std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
- bool* iterable_out) {
+ bool* iterable_out) {
std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
*iterable_out = false;
StringPiece attr_type = attr_def.type();
@@ -185,7 +184,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else if (attr_type == "tensor") {
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
- .add_parameter(Type::Wildcard()));
+ .add_parameter(Type::Wildcard()));
} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
@@ -196,7 +195,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
return types;
@@ -219,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) {
return result;
}
-bool FindAndCut(re2::StringPiece* input, const RE2& expr,
- re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
- re2::StringPiece match;
- if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) {
- return false;
- }
- before_match->set(input->data(), match.begin() - input->begin());
- input->remove_prefix(match.end() - before_match->begin());
- if (ret_match != nullptr) {
- *ret_match = match;
- }
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+ string* ret_match = nullptr) {
+ string match;
+ if (!RE2::PartialMatch(*input, expr, &match)) return false;
+ *before_match = input->substr(0, input->find(match));
+ *input = input->substr(before_match->size() + match.size());
+ if (ret_match != nullptr) *ret_match = match;
return true;
}
-string ParseDocumentation(re2::StringPiece input) {
+string ParseDocumentation(const string& inp) {
std::stringstream javadoc_text;
// TODO(karllessard) This is a very minimalist utility method for converting
// markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
// for alternatives to increase the level of support for markups.
std::vector<string> markups_subexpr;
- markups_subexpr.push_back("\n+\\*\\s+"); // lists
- markups_subexpr.push_back("\n{2,}"); // paragraphs
+ markups_subexpr.push_back("\n+\\*\\s+"); // lists
+ markups_subexpr.push_back("\n{2,}"); // paragraphs
markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks
- markups_subexpr.push_back("`+"); // inlined code and code blocks
+ markups_subexpr.push_back("`+"); // inlined code and code blocks
markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis
- markups_subexpr.push_back("\\["); // hyperlinks
- const RE2 markup_expr(str_util::Join(markups_subexpr, "|"));
+ markups_subexpr.push_back("\\["); // hyperlinks
+ const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
bool in_list = false;
+ string input = inp;
while (true) {
- re2::StringPiece text;
- re2::StringPiece markup;
+ string text, markup;
if (!FindAndCut(&input, markup_expr, &text, &markup)) {
javadoc_text << input;
break; // end of loop
}
javadoc_text << text;
- if (markup.starts_with("\n")) {
+ if (str_util::StartsWith(markup, "\n")) {
javadoc_text << "\n";
- if (markup.contains("*")) {
+ if (str_util::StrContains(markup, "*")) {
// new list item
javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
in_list = true;
@@ -267,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) {
// end of list
javadoc_text << "</li>\n</ul>\n";
in_list = false;
- } else if (!input.starts_with("```")) {
+ } else if (!str_util::StartsWith(input, "```")) {
// new paragraph (not required if a <pre> block follows)
javadoc_text << "<p>\n";
}
- } else if (markup.starts_with("```")) {
+ } else if (str_util::StartsWith(markup, "```")) {
// code blocks
- if (FindAndCut(&input, "```\\s*\n*", &text)) {
+ if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("`")) {
+ } else if (str_util::StartsWith(markup, "`")) {
// inlined code
- if (FindAndCut(&input, markup, &text)) {
+ if (FindAndCut(&input, "(" + markup + ")", &text)) {
javadoc_text << "{@code " << text << "}";
@@ -287,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) {
}
} else if (markup == "**") {
// text emphasis (strong)
- if (FindAndCut(&input, "\\b\\*{2}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
} else {
javadoc_text << markup;
}
} else if (markup == "*") {
// text emphasis (normal)
- if (FindAndCut(&input, "\\b\\*{1}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("[")) {
+ } else if (str_util::StartsWith(markup, "[")) {
// hyperlinks
string label;
string link;
- if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) {
+ if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+ &link) &&
+ str_util::StartsWith(input, label + link)) {
+ input = input.substr(label.size() + link.size());
javadoc_text << "<a href=\"" << link << "\">"
- << ParseDocumentation(label)
- << "</a>";
+ << ParseDocumentation(label) << "</a>";
} else {
javadoc_text << markup;
}
@@ -319,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) {
}
ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
- const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Arg& input_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(input_def, &iterable);
- Type var_type = Type::Interface("Operand", "org.tensorflow")
- .add_parameter(type);
+ Type var_type =
+ Type::Interface("Operand", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::IterableOf(var_type);
}
- return ArgumentSpec(input_api_def.name(),
+ return ArgumentSpec(
+ input_api_def.name(),
Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
- type,
- ParseDocumentation(input_api_def.description()),
- iterable);
+ type, ParseDocumentation(input_api_def.description()), iterable);
}
AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
- const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Attr& attr_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
- Type var_type = types.first.kind() == Type::GENERIC ?
- Type::Class("Class").add_parameter(types.first) : types.first;
+ Type var_type = types.first.kind() == Type::GENERIC
+ ? Type::Class("Class").add_parameter(types.first)
+ : types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return AttributeSpec(attr_api_def.name(),
+ return AttributeSpec(
+ attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
- types.first,
- types.second,
- ParseDocumentation(attr_api_def.description()),
- iterable,
- attr_api_def.has_default_value());
+ types.first, types.second, ParseDocumentation(attr_api_def.description()),
+ iterable, attr_api_def.has_default_value());
}
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
- const ApiDef::Arg& output_api, TypeResolver* type_resolver) {
+ const ApiDef::Arg& output_api,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(output_def, &iterable);
- Type var_type = Type::Class("Output", "org.tensorflow")
- .add_parameter(type);
+ Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return ArgumentSpec(output_api.name(),
+ return ArgumentSpec(
+ output_api.name(),
Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
- type,
- ParseDocumentation(output_api.description()),
- iterable);
+ type, ParseDocumentation(output_api.description()), iterable);
}
EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
- const ApiDef_Endpoint& endpoint_def) {
+ const ApiDef_Endpoint& endpoint_def) {
std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
string package;
string name;
@@ -377,27 +373,25 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
package = name_tokens.at(0);
name = name_tokens.at(1);
} else {
- package = kDefaultEndpointPackage;
+ package = "core"; // generate unclassified ops in the 'core' package
name = name_tokens.at(0);
}
- return EndpointSpec(package,
- name,
- Javadoc::Create(ParseDocumentation(api_def.summary()))
- .details(ParseDocumentation(api_def.description())));
+ return EndpointSpec(package, name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())));
}
} // namespace
OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
- OpSpec op(api_def.graph_op_name(),
- api_def.visibility() == ApiDef::HIDDEN,
- op_def.deprecation().explanation());
+ OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
TypeResolver type_resolver(op_def);
for (const string& next_input_name : api_def.arg_order()) {
for (int i = 0; i < op_def.input_arg().size(); ++i) {
if (op_def.input_arg(i).name() == next_input_name) {
op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
- &type_resolver));
+ &type_resolver));
break;
}
}
@@ -406,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
// do not parse attributes already visited, they have probably been inferred
// before as an input argument type
if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
- AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
- &type_resolver);
+ AttributeSpec attr =
+ CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
// attributes with a default value are optional
if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
op.optional_attributes_.push_back(attr);
@@ -417,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
}
}
for (int i = 0; i < op_def.output_arg().size(); ++i) {
- op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i),
- &type_resolver));
+ op.outputs_.push_back(
+ CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
}
for (const auto& endpoint_def : api_def.endpoint()) {
op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index 3b53c730df..30ecb8ce53 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/java/src/gen/cc/java_defs.h"
namespace tensorflow {
@@ -38,9 +38,8 @@ class EndpointSpec {
// javadoc: the endpoint class documentation
// TODO(annarev): hardcode deprecated to false until deprecated is possible
EndpointSpec(const string& package, const string& name,
- const Javadoc& javadoc)
- : package_(package), name_(name), javadoc_(javadoc),
- deprecated_(false) {}
+ const Javadoc& javadoc)
+ : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {}
const string& package() const { return package_; }
const string& name() const { return name_; }
@@ -63,10 +62,13 @@ class ArgumentSpec {
// type: the tensor type of this argument
// description: a description of this argument, in javadoc
// iterable: true if this argument is a list
- ArgumentSpec(const string& op_def_name, const Variable& var,
- const Type& type, const string& description, bool iterable)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable) {}
+ ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type,
+ const string& description, bool iterable)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -94,11 +96,16 @@ class AttributeSpec {
// iterable: true if this attribute is a list
// has_default_value: true if this attribute has a default value if not set
AttributeSpec(const string& op_def_name, const Variable& var,
- const Type& type, const Type& jni_type, const string& description,
- bool iterable, bool has_default_value)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable),
- jni_type_(jni_type), has_default_value_(has_default_value) {}
+ const Type& type, const Type& jni_type,
+ const string& description, bool iterable,
+ bool has_default_value)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable),
+ jni_type_(jni_type),
+ has_default_value_(has_default_value) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -147,9 +154,10 @@ class OpSpec {
// hidden: true if this op should not be visible through the Graph Ops API
// deprecation_explanation: message to show if all endpoints are deprecated
explicit OpSpec(const string& graph_op_name, bool hidden,
- const string& deprecation_explanation)
- : graph_op_name_(graph_op_name), hidden_(hidden),
- deprecation_explanation_(deprecation_explanation) {}
+ const string& deprecation_explanation)
+ : graph_op_name_(graph_op_name),
+ hidden_(hidden),
+ deprecation_explanation_(deprecation_explanation) {}
const string graph_op_name_;
const bool hidden_;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java
new file mode 100644
index 0000000000..13bc463e7d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * Interface implemented by operands of a TensorFlow operation.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * // The "decodeJpeg" operation can be used as input to the "cast" operation
+ * Input decodeJpeg = ops.image().decodeJpeg(...);
+ * ops.math().cast(decodeJpeg, DataType.FLOAT);
+ *
+ * // The output "y" of the "unique" operation can be used as input to the "cast" operation
+ * Output y = ops.array().unique(...).y();
+ * ops.math().cast(y, DataType.FLOAT);
+ *
+ * // The "split" operation can be used as input list to the "concat" operation
+ * Iterable<? extends Input> split = ops.array().split(...);
+ * ops.array().concat(0, split);
+ * }</pre>
+ */
+public interface Input<T> {
+
+ /**
+ * Returns the symbolic handle of a tensor.
+ *
+ * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is
+ * used to obtain a symbolic handle that represents the computation of the input.
+ *
+ * @see OperationBuilder#addInput(Output)
+ */
+ Output<T> asOutput();
+}
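
A minimal sketch of the implementing side of the new Input<T> interface, assuming only the Input and Output types shown above; the ConstantLike wrapper is hypothetical and not part of this change:

    import org.tensorflow.Input;
    import org.tensorflow.Output;

    // Hypothetical operand wrapper: anything that can expose its symbolic
    // handle via asOutput() can be fed to a TensorFlow operation.
    final class ConstantLike<T> implements Input<T> {
      private final Output<T> handle; // symbolic handle produced elsewhere

      ConstantLike(Output<T> handle) {
        this.handle = handle;
      }

      @Override
      public Output<T> asOutput() {
        return handle; // consumers pass this to OperationBuilder#addInput(Output)
      }
    }
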
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
new file mode 100644
index 0000000000..ab34f6aa12
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a boolean. */
+public class TFBool implements TFType {
+ private TFBool() {}
+ static {
+ Types.typeCodes.put(TFBool.class, DataType.BOOL);
+ }
+ static {
+ Types.scalars.put(TFBool.class, false);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
new file mode 100644
index 0000000000..49e5d9f2f3
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit double precision floating point number. */
+public class TFDouble implements TFType {
+ private TFDouble() {}
+ static {
+ Types.typeCodes.put(TFDouble.class, DataType.DOUBLE);
+ }
+ static {
+ Types.scalars.put(TFDouble.class, 0.0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
new file mode 100644
index 0000000000..8426ee41f0
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit single precision floating point number. */
+public class TFFloat implements TFType {
+ private TFFloat() {}
+ static {
+ Types.typeCodes.put(TFFloat.class, DataType.FLOAT);
+ }
+ static {
+ Types.scalars.put(TFFloat.class, 0f);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
new file mode 100644
index 0000000000..3947b6ad09
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit signed integer. */
+public class TFInt32 implements TFType {
+ private TFInt32() {}
+ static {
+ Types.typeCodes.put(TFInt32.class, DataType.INT32);
+ }
+ static {
+ Types.scalars.put(TFInt32.class, 0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
new file mode 100644
index 0000000000..ccdded8693
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit signed integer. */
+public class TFInt64 implements TFType {
+ private TFInt64() {}
+ static {
+ Types.typeCodes.put(TFInt64.class, DataType.INT64);
+ }
+ static {
+ Types.scalars.put(TFInt64.class, 0L);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
new file mode 100644
index 0000000000..e7327e8c57
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an arbitrary sequence of bytes. */
+public class TFString implements TFType {
+ private TFString() {}
+ static {
+ Types.typeCodes.put(TFString.class, DataType.STRING);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
new file mode 100644
index 0000000000..562953ac9d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
@@ -0,0 +1,20 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+/**
+ * A marker interface for classes representing TensorFlow types.
+ */
+public interface TFType {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
new file mode 100644
index 0000000000..d7305ca5a8
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an 8-bit unsigned integer. */
+public class TFUInt8 implements TFType {
+ private TFUInt8() {}
+ static {
+ Types.typeCodes.put(TFUInt8.class, DataType.UINT8);
+ }
+ static {
+ Types.scalars.put(TFUInt8.class, (byte)0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
new file mode 100644
index 0000000000..976cd9fd34
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.tensorflow.DataType;
+
+/**
+ * Utility class for managing the representation of TensorFlow types as Java
+ * types. For each TensorFlow type (e.g., int32), there is a corresponding Java
+ * type (e.g., TFInt32) that represents it at compile time and a corresponding
+ * class object (e.g., TFInt32.class) that represents it at run time. There is
+ * also an enumeration value in DataType that can be used to represent the
+ * type, though that should rarely be required.
+ */
+public class Types {
+
+ private Types() {} // not instantiable
+
+ static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
+
+ /** Returns the DataType value corresponding to a TensorFlow type class. */
+ public static DataType dataType(Class<? extends TFType> c) {
+ DataType dtype = typeCodes.get(c);
+ if (dtype == null) {
+ throw new IllegalArgumentException("" + c + " is not a TensorFlow type.");
+ }
+ return dtype;
+ }
+
+ static final Map<Class<?>, Object> scalars = new HashMap<>();
+
+ /** Returns the zero value of the type described by {@code c}, or null if
+ * the type (e.g., string) is not numeric and therefore has no zero value.
+ */
+ public static Object zeroValue(Class<? extends TFType> c) {
+ return scalars.get(c);
+ }
+}
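
A short usage sketch for the Types helper above. One assumption is made explicit here: the typeCodes/scalars maps are populated by static initializers in the TF* classes, so a type class must have been initialized before it is looked up (merely writing TFInt32.class does not trigger initialization):

    import org.tensorflow.DataType;
    import org.tensorflow.types.TFInt32;
    import org.tensorflow.types.Types;

    public class TypesDemo {
      public static void main(String[] args) throws Exception {
        // Force static initialization of TFInt32 so its registration blocks run.
        Class.forName("org.tensorflow.types.TFInt32");

        DataType dtype = Types.dataType(TFInt32.class); // DataType.INT32
        Object zero = Types.zeroValue(TFInt32.class);   // Integer 0
        System.out.println(dtype + ", zero = " + zero);
      }
    }
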
diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc
index cb54daf137..8b11525785 100644
--- a/tensorflow/java/src/main/native/session_jni.cc
+++ b/tensorflow/java/src/main/native/session_jni.cc
@@ -86,20 +86,22 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
- const char* ctarget = nullptr;
jbyte* cconfig = nullptr;
- if (target != nullptr) {
- ctarget = env->GetStringUTFChars(target, nullptr);
- }
if (config != nullptr) {
cconfig = env->GetByteArrayElements(config, nullptr);
TF_SetConfig(opts, cconfig,
static_cast<size_t>(env->GetArrayLength(config)), status);
if (!throwExceptionIfNotOK(env, status)) {
env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
+ TF_DeleteSessionOptions(opts);
+ TF_DeleteStatus(status);
return 0;
}
}
+ const char* ctarget = nullptr;
+ if (target != nullptr) {
+ ctarget = env->GetStringUTFChars(target, nullptr);
+ }
TF_Session* session = TF_NewSession(graph, opts, status);
if (config != nullptr) {
env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
index 7922f3329c..b063b6f1cd 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -47,7 +47,7 @@ public class SavedModelBundleTest {
fail("not expected");
} catch (org.tensorflow.TensorFlowException e) {
// expected exception
- assertTrue(e.getMessage().contains("SavedModel not found"));
+ assertTrue(e.getMessage().contains("Could not find SavedModel"));
}
}
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ebfcfff4a5..924db54cbc 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -73,7 +73,7 @@ py_library(
visibility = [
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
- "//tensorflow/tools/api/generator:__pkg__",
+ "//tensorflow/python/tools/api/generator:__pkg__",
],
deps = [
":array_ops",
@@ -822,6 +822,7 @@ py_library(
":platform",
":registry",
":tensor_shape",
+ ":traceable_stack",
":util",
":versions",
"//tensorflow/core:protos_all_py",
@@ -887,6 +888,17 @@ py_library(
],
)
+# This target is maintained separately from :util to provide separate visibility
+# for legacy users who were granted visibility when the functions were private
+# members of ops.Graph.
+py_library(
+ name = "tf_stack",
+ srcs = ["util/tf_stack.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [],
+)
+
py_library(
name = "tensor_shape",
srcs = ["framework/tensor_shape.py"],
@@ -922,6 +934,16 @@ py_library(
)
py_library(
+ name = "traceable_stack",
+ srcs = ["framework/traceable_stack.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":util",
+ ],
+)
+
+py_library(
name = "versions",
srcs = ["framework/versions.py"],
srcs_version = "PY2AND3",
@@ -1207,6 +1229,21 @@ py_test(
],
)
+py_test(
+ name = "framework_traceable_stack_test",
+ size = "small",
+ srcs = ["framework/traceable_stack_test.py"],
+ main = "framework/traceable_stack_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ ":test_ops",
+ ":traceable_stack",
+ ":util",
+ ],
+)
+
tf_gen_op_wrapper_py(
name = "test_ops",
out = "framework/test_ops.py",
@@ -4079,6 +4116,7 @@ cuda_py_test(
":math_ops",
"//tensorflow/core:protos_all_py",
],
+ tags = ["no_windows_gpu"],
)
py_test(
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 3bde62fa1d..38505c0a01 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -349,6 +349,7 @@ tf_py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
],
grpc_enabled = True,
)
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 820c167b6b..b434fa7334 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -415,6 +416,69 @@ class IteratorTest(test.TestCase):
sess.run(
next_element, feed_dict={handle_placeholder: iterator_4_handle})
+ def testIteratorStringHandleFuture(self):
+ with forward_compat.forward_compatibility_horizon(2018, 8, 4):
+ dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
+ dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
+
+ iterator_3 = dataset_3.make_one_shot_iterator()
+ iterator_4 = dataset_4.make_one_shot_iterator()
+
+ handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+ feedable_iterator = iterator_ops.Iterator.from_string_handle(
+ handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
+ next_element = feedable_iterator.get_next()
+
+ self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
+ self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
+ self.assertEqual([], feedable_iterator.output_shapes)
+
+ with self.test_session() as sess:
+ iterator_3_handle = sess.run(iterator_3.string_handle())
+ iterator_4_handle = sess.run(iterator_4.string_handle())
+
+ self.assertEqual(
+ 10,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 1,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 20,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 2,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 30,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 3,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 40,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(
+ next_element, feed_dict={handle_placeholder: iterator_3_handle})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(
+ next_element, feed_dict={handle_placeholder: iterator_4_handle})
+
def testIteratorStringHandleReuseTensorObject(self):
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
one_shot_iterator = dataset.make_one_shot_iterator()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 0ecd821e9e..637bde9ae4 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -666,6 +666,13 @@ class MapDatasetTest(test.TestCase):
"currently support nested datasets as outputs."):
_ = dataset.map(dataset_ops.Dataset.from_tensor_slices)
+ def testReturnValueError(self):
+ dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
+ with self.assertRaisesRegexp(
+ TypeError, r"Unsupported return value from function passed to "
+ r"Dataset.map\(\): None."):
+ _ = dataset.map(lambda x: None)
+
class MapDatasetBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index fa2e86eab1..f15eb6310f 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -40,6 +40,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/compat",
"//tensorflow/python/data/util:convert",
],
)
@@ -54,6 +55,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index d2a8c0f313..88de4b588c 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -24,6 +24,7 @@ import warnings
import numpy as np
import six
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
@@ -107,8 +108,12 @@ class Dataset(object):
"execution is enabled.")
if shared_name is None:
shared_name = ""
- iterator_resource = gen_dataset_ops.iterator(
- container="", shared_name=shared_name, **flat_structure(self))
+ if compat.forward_compatible(2018, 8, 3):
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="", shared_name=shared_name, **flat_structure(self))
+ else:
+ iterator_resource = gen_dataset_ops.iterator(
+ container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
iterator_resource)
@@ -1425,7 +1430,11 @@ class StructuredFunctionWrapper(object):
flat_shapes.append(component)
flat_types.append(component)
else:
- t = ops.convert_to_tensor(t)
+ try:
+ t = ops.convert_to_tensor(t)
+ except (ValueError, TypeError):
+ raise TypeError("Unsupported return value from function passed to "
+ "%s: %s." % (transformation_name, t))
flat_ret.append(t)
flat_classes.append(ops.Tensor)
flat_shapes.append(t.get_shape())
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index b6dba4e3ca..35de2f2841 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import threading
import warnings
+from tensorflow.python.compat import compat
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
@@ -172,13 +173,32 @@ class Iterator(object):
nest.assert_same_structure(output_types, output_shapes)
if shared_name is None:
shared_name = ""
- iterator_resource = gen_dataset_ops.iterator(
- container="",
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(output_types, output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(output_shapes, output_classes)))
+ if compat.forward_compatible(2018, 8, 3):
+ if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ with ops.device("/cpu:0"):
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
return Iterator(iterator_resource, None, output_types, output_shapes,
output_classes)
@@ -242,12 +262,29 @@ class Iterator(object):
output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
nest.assert_same_structure(output_types, output_shapes)
string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
- iterator_resource = gen_dataset_ops.iterator_from_string_handle(
- string_handle,
- output_types=nest.flatten(
- sparse.as_dense_types(output_types, output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(output_shapes, output_classes)))
+ if compat.forward_compatible(2018, 8, 3):
+ if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ with ops.device("/cpu:0"):
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
return Iterator(iterator_resource, None, output_types, output_shapes,
output_classes)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index c025dc8aa5..27b8ebd362 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -404,6 +404,7 @@ py_library(
deps = [
":debug_errors",
":debug_fibonacci",
+ ":debug_keras",
":debug_mnist",
":debug_tflearn_iris",
],
@@ -802,6 +803,7 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
+ tags = ["no_windows_gpu"],
)
py_test(
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 3e3c82e56a..9e0bbce4a1 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -713,10 +713,15 @@ class GradientTape(object):
if self._recording:
self._pop_tape()
- def _push_tape(self):
+ def _push_tape(self, existing_tape=False):
if self._recording:
raise ValueError("Tape is already recording.")
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ if existing_tape:
+ if self._tape is None:
+ raise ValueError("There is no existing tape.")
+ tape.push_tape(self._tape)
+ else:
+ self._tape = tape.push_new_tape(persistent=self._persistent)
self._recording = True
def _pop_tape(self):
@@ -764,7 +769,7 @@ class GradientTape(object):
try:
yield
finally:
- self._push_tape()
+ self._push_tape(existing_tape=True)
def reset(self):
"""Clears all information stored in this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ebbd3cd98e..bdda200ff6 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -223,11 +223,23 @@ class BackpropTest(test.TestCase):
def testTapeStopRecording(self):
with backprop.GradientTape() as t:
- x = constant_op.constant(1.0)
+ x = resource_variable_ops.ResourceVariable(1.0)
with t.stop_recording():
y = x * x
self.assertEqual(t.gradient(y, x), None)
+ def testTapeStopStartRecording(self):
+ with backprop.GradientTape(persistent=True) as t:
+ x = resource_variable_ops.ResourceVariable(1.0)
+ x2 = x * 2 # This should be differentiated through.
+ with t.stop_recording():
+ y = x2 * x2
+ z = x2 * x2
+ self.assertEqual(t.gradient(y, x2), None)
+
+ # If the x*2 was not differentiated through, this would be 2.0, not 4.0
+ self.assertEqual(t.gradient(z, x2).numpy(), 4.0)
+
def testTapeReset(self):
with backprop.GradientTape() as t:
v = resource_variable_ops.ResourceVariable(1.0)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 7a7e8cd219..a6906f9efd 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import threading
import numpy as np
@@ -137,7 +138,7 @@ class CapturingGraph(ops.Graph):
inputs[i] = self.capture(inp)
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
- compute_shapes, compute_device)
+ compute_device=compute_device)
# pylint: disable=invalid-name
@@ -469,7 +470,7 @@ class GraphModeFunction(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
- with self._graph.as_default(), context.graph_mode():
+ with self._graph.as_default():
c_known_ops = set()
c_captured_tensors = set()
@@ -656,55 +657,58 @@ def _deterministic_dict_values(kwds):
def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- captures = {}
- tmp_graph = CapturingGraph(captures)
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- # Copy the graph collections to ensure summaries and other things work. This
- # lets the function access (but not mutate) collections of the containing
- # graph, such as the global step and the summary writer collections.
- curr_graph = ops.get_default_graph()
- for collection in curr_graph.collections:
- tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
- collection)
- with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_args = _get_defun_inputs(args)
- func_kwds = _get_defun_inputs(kwds)
-
- def convert(x):
- if x is None:
- return None
- x = ops.convert_to_tensor_or_indexed_slices(x)
- x = a.mark_as_return(x)
- return x
-
- this_tape = tape.push_new_tape()
- try:
- func_outputs = func(*func_args, **func_kwds)
- func_outputs = nest.map_structure(convert, func_outputs)
- finally:
- tape.pop_tape(this_tape)
- variables = this_tape.watched_variables()
-
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- outputs_list = _flatten(func_outputs)
- func_def_outputs = [
- tmp_graph.capture(x) for x in outputs_list
- if x is not None
- ]
-
- ids = list(sorted(captures.keys()))
- if ids:
- extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
- else:
- extra_inputs = []
- extra_placeholders = []
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in func_def_outputs)
+ captures = {}
+ tmp_graph = CapturingGraph(captures)
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ tmp_graph._graph_key = graph_key # pylint: disable=protected-access
+ # Copy the graph collections to ensure summaries and other things work. This
+ # lets the function access (but not mutate) collections of the containing
+ # graph, such as the global step and the summary writer collections.
+ curr_graph = ops.get_default_graph()
+ for collection in curr_graph.collections:
+ tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+ collection)
+ if context.executing_eagerly():
+ tmp_graph.seed = context.global_seed()
+ else:
+ tmp_graph.seed = curr_graph.seed
+ with tmp_graph.as_default(), AutomaticControlDependencies() as a:
+ func_args = _get_defun_inputs(args)
+ func_kwds = _get_defun_inputs(kwds)
+
+ def convert(x):
+ if x is None:
+ return None
+ x = ops.convert_to_tensor_or_indexed_slices(x)
+ x = a.mark_as_return(x)
+ return x
+
+ this_tape = tape.push_new_tape()
+ try:
+ func_outputs = func(*func_args, **func_kwds)
+ func_outputs = nest.map_structure(convert, func_outputs)
+ finally:
+ tape.pop_tape(this_tape)
+ variables = this_tape.watched_variables()
+
+ # Returning a closed-over tensor as an output does not trigger a
+ # call to convert_to_tensor, so we manually capture all such tensors.
+ outputs_list = _flatten(func_outputs)
+ func_def_outputs = [
+ tmp_graph.capture(x) for x in outputs_list
+ if x is not None
+ ]
+
+ ids = list(sorted(captures.keys()))
+ if ids:
+ extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
+ else:
+ extra_inputs = []
+ extra_placeholders = []
+ output_shapes = tuple(
+ x.shape if isinstance(x, ops.Tensor) else None
+ for x in func_def_outputs)
func_kwds_values = _deterministic_dict_values(func_kwds)
flat_inputs = [
@@ -770,6 +774,11 @@ class _PolymorphicFunction(object):
See the documentation for `defun` for more information on the semantics of
defined functions.
+
+ _PolymorphicFunction class is thread-compatible, meaning that minimal
+ usage of defuns (defining and calling) is thread-safe, but if users call
+ other methods or invoke the base `python_function` themselves, external
+ synchronization is necessary.
"""
def __init__(self, python_function, name, compiled=False):
@@ -787,6 +796,8 @@ class _PolymorphicFunction(object):
self._arguments_to_functions = {}
self._variables = []
+ self._lock = threading.Lock()
+
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
@@ -825,15 +836,16 @@ class _PolymorphicFunction(object):
# signature so we don't improperly capture tensors such as variables.
signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
- if signature not in self._arguments_to_functions:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds)
- self._arguments_to_functions[signature] = graph_function
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
- return graph_function, inputs
- else:
- return self._arguments_to_functions[signature], inputs
+ with self._lock:
+ if signature not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[signature] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function, inputs
+ else:
+ return self._arguments_to_functions[signature], inputs
def __call__(self, *args, **kwds):
"""Calls a graph function specialized for this input signature."""
@@ -1296,7 +1308,7 @@ class AutomaticControlDependencies(object):
# Ensures the merge always runs
ops_which_must_run.add(new_merge[0].op)
if inp in last_op_using_resource_tensor:
- # Ensures the switch exectutes after the previous op using the resource.
+ # Ensures the switch executes after the previous op using the resource.
switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
# Ensure the next op outside the cond happens after the merge.
last_op_using_resource_tensor[inp] = new_merge[0].op
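
The thread-compatibility contract described in the _PolymorphicFunction docstring above amounts to a lock-guarded trace cache; a sketch of the same pattern follows (illustrative only, not TensorFlow API):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Function;

    // Lock-guarded memoization: concurrent callers are safe because the
    // signature -> traced-function map is only read and written while the
    // lock is held, mirroring _PolymorphicFunction._lock above.
    final class TraceCache<K, V> {
      private final Map<K, V> cache = new HashMap<>();

      synchronized V getOrTrace(K signature, Function<K, V> trace) {
        V fn = cache.get(signature);
        if (fn == null) {
          fn = trace.apply(signature); // first caller traces; later callers reuse
          cache.put(signature, fn);
        }
        return fn;
      }
    }
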
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 1de25811b4..cdd9fe1760 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,8 +27,10 @@ from tensorflow.python.eager import function
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.layers import convolutional
@@ -38,11 +40,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import momentum
+from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
@@ -134,6 +138,18 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def disabled_testRandomSeed(self):
+
+ @function.defun
+ def f():
+ return random_ops.random_normal(())
+
+ random_seed.set_random_seed(1)
+ x = f()
+ self.assertNotEqual(x, f())
+ random_seed.set_random_seed(1)
+ self.assertAllEqual(f(), x)
+
def testNestedInputsDefunOpGraphMode(self):
matmul = function.defun(math_ops.matmul)
@@ -545,10 +561,8 @@ class FunctionTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFunctionWithResourcesOnDifferentDevices(self):
- # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as
- # safe to be invoked with resources on different devices.
- self.skipTest('The Placer disallows ops with resource inputs '
- 'on different devices.')
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
@@ -567,6 +581,44 @@ class FunctionTest(test.TestCase):
expected = self.evaluate(sum_gather())
self.assertAllEqual(expected, self.evaluate(defined()))
+ @test_util.run_in_graph_and_eager_modes
+ def testOpInFunctionWithConflictingResourceInputs(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('/cpu:0'):
+ v_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='cpu')
+ v_also_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='also_cpu')
+
+ with ops.device('/gpu:0'):
+ v_gpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='gpu')
+
+ @function.defun
+ def resource_apply_adam():
+ training_ops.resource_apply_adam(
+ v_cpu.handle,
+ v_gpu.handle,
+ v_also_cpu.handle,
+ 1.0, # beta1_power
+ 1.0, # beta2_power
+ 1.0, # learning_rate
+ 1.0, # beta1
+ 1.0, # beta2
+ 1.0, # epsilon,
+ [1.0, 1.0, 1.0], # grad
+ False) # use_locking
+ return None
+
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, 'Could not colocate node with its '
+ 'resource and reference inputs.*'):
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(resource_apply_adam())
+
def testFunctionHandlesInputsOnDifferentDevices(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
@@ -1102,7 +1154,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
def loss(v):
return v**2
- optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+ optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0)
@function.defun
def train():
@@ -1119,7 +1171,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
def loss():
return v**2
- optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+ optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0)
@function.defun
def train():
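The optimizer swap above replaces a slot-free optimizer with one that creates slot variables (`MomentumOptimizer` keeps a momentum accumulator per variable), presumably to exercise automatic control dependencies over more state. A minimal sketch of the pattern under test, using the TF 1.x-era internal modules imported in this file (the body of `train` is illustrative, not the test's exact code):

```python
from tensorflow.python.eager import function
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import momentum

v = resource_variable_ops.ResourceVariable(1.0)
optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0)

@function.defun
def train():
  # Automatic control dependencies sequence the variable and slot updates
  # from apply_gradients before the final read of `v`.
  grad, = gradients_impl.gradients(v * v, [v])
  optimizer.apply_gradients([(grad, v)])
  return v.read_value()
```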
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 848adf4fd3..2c6f04d8ad 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -118,7 +118,7 @@ class _VariableCapturingScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
@@ -156,7 +156,7 @@ class _VariableCapturingScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 57b4dab51c..ec7e2371e9 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1898,14 +1898,39 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "trainable"));
+ DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "trainable"));
+ PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
+bool CastTensor(const FastPathOpExecInfo& op_exec_info,
+ const TF_DataType& desired_dtype,
+ tensorflow::Safe_TFE_TensorHandlePtr* handle,
+ TF_Status* status) {
+ TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
+ TF_DataType output_dtype = input_dtype;
+
+ if (desired_dtype >= 0 && desired_dtype != input_dtype) {
+ *handle = tensorflow::make_safe(
+ tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
+ static_cast<TF_DataType>(desired_dtype), status));
+ if (!status->status.ok()) return false;
+ output_dtype = desired_dtype;
+ }
+
+ if (output_dtype != TF_INT32) {
+ // Note that this is a shallow copy and will share the underlying buffer
+ // if copying to the same device.
+ *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
+ handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
+ if (!status->status.ok()) return false;
+ }
+ return true;
+}
+
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
@@ -1938,9 +1963,31 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
TFE_Execute(op, &output_handle, &num_retvals, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
- // Always create the py object (and correctly DECREF it) from the returned
- // value, else the data will leak.
- output->reset(EagerTensorFromHandle(output_handle));
+ if (!PyObject_HasAttrString(input, "_read_dtype")) {
+ // Always create the py object (and correctly DECREF it) from the returned
+ // value, else the data will leak.
+ output->reset(EagerTensorFromHandle(output_handle));
+ } else {
+ // This is a _MixedPrecisionVariable which potentially does casting when
+ // being read.
+ tensorflow::Safe_PyObjectPtr read_dtype(
+ PyObject_GetAttrString(input, "_read_dtype"));
+ int desired_dtype = -1;
+ if (!ParseTypeValue("_read_dtype", read_dtype.get(), status,
+ &desired_dtype)) {
+ return false;
+ }
+
+ auto safe_output_handle = tensorflow::make_safe(output_handle);
+ // safe_output_handle now owns output_handle; clear the raw pointer.
+ output_handle = nullptr;
+ if (!CastTensor(parent_op_exec_info,
+ static_cast<TF_DataType>(desired_dtype),
+ &safe_output_handle, status)) {
+ return false;
+ }
+ output->reset(EagerTensorFromHandle(safe_output_handle.release()));
+ }
// TODO(nareshmodi): Should we run post exec callbacks here?
if (parent_op_exec_info.run_gradient_callback) {
@@ -2010,27 +2057,13 @@ bool ConvertToTensor(
}
}
- TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
- if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
- handle = tensorflow::make_safe(
- tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype,
- static_cast<TF_DataType>(desired_dtype), status));
- if (!status->status.ok()) return false;
-
- handle_dtype = TFE_TensorHandleDataType(handle.get());
- }
-
- if (handle_dtype != TF_INT32) {
- // Note that this is a shallow copy and will share the underlying buffer
- // if copying to the same device.
- handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
- handle.get(), op_exec_info.ctx, op_exec_info.device_name, status));
- if (!status->status.ok()) return false;
+ if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
+ &handle, status)) {
+ return false;
}
-
+ TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
output_handle->reset(EagerTensorFromHandle(handle.release()));
-
- dtype_setter(handle_dtype);
+ dtype_setter(output_dtype);
return true;
}
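In Python terms, the fast-path read of a `_MixedPrecisionVariable` now behaves roughly like the sketch below: read the stored value, then cast to the declared read dtype. This is a simplification of `ReadVariableOp` followed by `CastTensor`, not the actual binding:

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops

def _read_with_cast(var, read_dtype=dtypes.float16):
  # Read the underlying (typically float32) storage, then cast the result,
  # mirroring ReadVariableOp + CastTensor in the C++ fast path above.
  value = var.read_value()
  if value.dtype != read_dtype:
    value = math_ops.cast(value, read_dtype)
  return value
```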
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index faaae40b3f..fd8ab695b8 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -71,6 +72,25 @@ class Tests(test.TestCase):
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
+ def testFastpathExecute_MixedPrecisionVariableMatMulCorrectResponse(self):
+ ctx = context.context()
+ a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
+ a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
+ m = resource_variable_ops.ResourceVariable(a_2_by_2)
+ m = resource_variable_ops._MixedPrecisionVariable(
+ m, read_dtype=dtypes.float16)
+ x = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a",
+ False, "transpose_b", False)
+ y = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16,
+ a_2_by_2_fp16, "transpose_a", False, "transpose_b", False)
+
+ self.assertEqual(x.dtype, dtypes.float16)
+ self.assertAllEqual(x, y)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
def testFastpathExecute_TapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
@@ -98,6 +118,29 @@ class Tests(test.TestCase):
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
+ ctx = context.context()
+ with backprop.GradientTape(persistent=True) as tape:
+ a_2_by_2 = constant_op.constant(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
+ m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
+ m2 = resource_variable_ops._MixedPrecisionVariable(
+ m1, read_dtype=dtypes.float16)
+ tape.watch(m2)
+ z = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16, m2,
+ "transpose_a", False, "transpose_b", False)
+ dz_dy = tape.gradient(z, [m2])[0]
+ self.assertEqual(dz_dy.dtype, dtypes.float16)
+
+ expected_grads = math_ops.matmul(
+ array_ops.transpose(a_2_by_2_fp16),
+ constant_op.constant(1., shape=[2, 2], dtype=dtypes.float16)).numpy()
+ self.assertAllEqual(dz_dy.numpy(), expected_grads)
+
# Tests homogeneous list op
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 8ee38d35cc..6c415b1bf2 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -707,6 +707,14 @@ py_library(
)
py_library(
+ name = "expect_h5py_installed",
+ # This is a dummy rule used as an h5py dependency in open-source.
+ # We expect h5py to already be installed on the system, e.g. via
+ # `pip install h5py'
+ visibility = ["//visibility:public"],
+)
+
+py_library(
name = "expect_six_installed",
# This is a dummy rule used as a numpy dependency in open-source.
# We expect six to already be installed on the system, e.g. via
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
index aa5a29e6dd..a75fa7d0ae 100644
--- a/tensorflow/python/estimator/api/BUILD
+++ b/tensorflow/python/estimator/api/BUILD
@@ -6,13 +6,14 @@ package(
licenses(["notice"]) # Apache 2.0
-load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
-load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
gen_api_init_files(
name = "estimator_python_api_gen",
api_name = "estimator",
output_files = ESTIMATOR_API_INIT_FILES,
+ output_package = "tensorflow.python.estimator.api",
package = "tensorflow.python.estimator",
package_dep = "//tensorflow/python/estimator:estimator_py",
)
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
index 7bf2e62da9..e46a3a156d 100644
--- a/tensorflow/python/estimator/canned/baseline_test.py
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -154,6 +154,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 9.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -176,6 +178,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -204,6 +208,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -229,7 +235,9 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is bias which is [46, 58]
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index a22e9745c1..3c832c7569 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -669,6 +669,8 @@ def _bt_model_fn(
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op
+ else:
+ return control_flow_ops.no_op()
def grow_not_in_mem():
"""Accumulates the data and grows a layer when ready."""
@@ -715,6 +717,8 @@ def _bt_model_fn(
name='wait_until_n_batches_accumulated')
return grow_model
+ else:
+ return control_flow_ops.no_op()
update_model = control_flow_ops.cond(
center_bias_var, center_bias_not_in_mem, grow_not_in_mem)
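The two added `else` branches guard against a Python-level condition falling through: the branch functions handed to `control_flow_ops.cond` are called while the graph is built, and every code path must return an op. A reduced sketch of the shape of the fix (names are made up):

```python
from tensorflow.python.ops import control_flow_ops

def _grow_or_skip(grow_ready, grow_model_fn):
  def _true_fn():
    if grow_model_fn is not None:      # evaluated at graph-construction time
      return grow_model_fn()
    else:
      return control_flow_ops.no_op()  # every path must yield a concrete op
  return control_flow_ops.cond(grow_ready, _true_fn, control_flow_ops.no_op)
```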
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index ba17821259..de226ed0ef 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -1271,6 +1271,8 @@ class BaseDNNRegressorEvaluateTest(object):
self.assertAllClose({
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
+ metric_keys.MetricKeys.PREDICTION_MEAN: -2.08,
+ metric_keys.MetricKeys.LABEL_MEAN: 1.0,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
@@ -1301,6 +1303,8 @@ class BaseDNNRegressorEvaluateTest(object):
self.assertAllClose({
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.39 / 3.0,
+ metric_keys.MetricKeys.LABEL_MEAN: 0.5 / 3.0,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index b74ef1015c..da9a64c2bc 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -1398,15 +1398,21 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
weights=weights,
processed_labels=labels)
- def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+ def _eval_metric_ops(self, predicted_value, labels, weights, unreduced_loss,
+ regularization_loss):
"""Returns the Eval metric ops."""
keys = metric_keys.MetricKeys
# Estimator already adds a metric for loss.
eval_metric_ops = {
_summary_key(self._name, keys.LOSS_MEAN):
- metrics_lib.mean(
- values=unreduced_loss,
- weights=weights)
+ metrics_lib.mean(values=unreduced_loss, weights=weights),
+ _summary_key(self._name, keys.PREDICTION_MEAN):
+ _predictions_mean(
+ predictions=predicted_value,
+ weights=weights,
+ name=keys.PREDICTION_MEAN),
+ _summary_key(self._name, keys.LABEL_MEAN):
+ metrics_lib.mean(values=labels, weights=weights)
}
if regularization_loss is not None:
regularization_loss_key = _summary_key(
@@ -1489,13 +1495,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
predictions=predictions,
loss=regularized_training_loss,
eval_metrics=_create_eval_metrics_tuple(
- self._eval_metric_ops,
- {
+ self._eval_metric_ops, {
+ 'predicted_value': predicted_value,
+ 'labels': labels,
'weights': weights,
'unreduced_loss': unreduced_loss,
'regularization_loss': regularization_loss,
- }
- ))
+ }))
# Train.
if optimizer is not None:
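Both new metrics are weighted means over the batch, i.e. `sum(w * v) / sum(w)`, which is what produces the weighted expectations in the tests below. A quick standalone check of the arithmetic (plain Python, no TF needed):

```python
def weighted_mean(values, weights):
  return sum(w * v for w, v in zip(weights, values)) / sum(weights)

# Predictions [45, 41, 44], labels [35, 42, 45], weights [1.0, 0.1, 1.5],
# matching the head_test.py expectations further down:
assert abs(weighted_mean([45., 41., 44.], [1.0, 0.1, 1.5]) - 115.1 / 2.6) < 1e-9
assert abs(weighted_mean([35., 42., 45.], [1.0, 0.1, 1.5]) - 106.7 / 2.6) < 1e-9
```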
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 08ce5ca8e8..bd2e0ae943 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -3103,8 +3103,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3140,6 +3142,9 @@ class RegressionHead(test.TestCase):
expected_metric_keys = [
'{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN),
+ '{}/some_regression_head'.format(
+ metric_keys.MetricKeys.PREDICTION_MEAN),
+ '{}/some_regression_head'.format(metric_keys.MetricKeys.LABEL_MEAN),
]
self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
@@ -3170,6 +3175,8 @@ class RegressionHead(test.TestCase):
expected_metrics = {
keys.LOSS_MEAN: expected_unregularized_loss,
keys.LOSS_REGULARIZATION: expected_regularization_loss,
+ keys.PREDICTION_MEAN: (45 + 41) / 2.0,
+ keys.LABEL_MEAN: (43 + 44) / 2.0,
}
# Assert predictions, loss, and metrics.
@@ -3471,8 +3478,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3700,8 +3709,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3832,7 +3843,13 @@ class RegressionHead(test.TestCase):
# losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]
# loss = sum(losses) = 100+.1+1.5 = 101.6
# loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923
- expected_metrics = {metric_keys.MetricKeys.LOSS_MEAN: 39.076923}
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 39.076923,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ (45 + 41 * 0.1 + 44 * 1.5) / 2.6,
+ metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) / 2.6,
+ }
# Assert spec contains expected tensors.
self.assertEqual(dtypes.float32, spec.loss.dtype)
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 9e9c2f7c4b..c3934c7a80 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -261,6 +261,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 9.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -286,6 +288,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -316,6 +320,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -346,7 +352,9 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is
# [2., 4., 5.] * [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0]
@@ -383,7 +391,9 @@ class BaseLinearRegressorEvaluationTest(object):
eval_metrics = est.evaluate(input_fn=input_fn, steps=1)
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =
# [213.0, 421.0], while label is [213., 421.]. Loss = 0.
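As a quick sanity check of the logit arithmetic spelled out in the comment above (plain Python; feature values 20/40 and 4/8 with weights 10.0 and 2.0 plus bias 5.0):

```python
logits = [20. * 10.0 + 4 * 2.0 + 5.0, 40. * 10.0 + 8 * 2.0 + 5.0]
assert logits == [213.0, 421.0]  # equal to the labels, hence loss == 0
```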
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 350a95eea1..253716b43e 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -576,7 +576,9 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
+ '_estimator_api_names_v1', '_estimator_api_constants',
+ '_estimator_api_constants_v1',
'_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index cb37f99704..076359b503 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -39,7 +39,6 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
-from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
@@ -71,16 +70,22 @@ def _convert_tensor(x):
return x
-def _any_variable_initialized():
- """Check if any variable has been initialized in the Keras model.
+def _any_weight_initialized(keras_model):
+ """Check if any weights has been initialized in the Keras model.
+
+ Args:
+ keras_model: An instance of a compiled Keras model.
Returns:
- boolean, True if at least one variable has been initialized, else False.
+ boolean, True if at least one weight has been initialized, else False.
+ Currently Keras initializes all weights at get_session().
"""
- variables = variables_module.global_variables()
- for v in variables:
- if getattr(v, '_keras_initialized', False):
- return True
+ if keras_model is None:
+ return False
+ for layer in keras_model.layers:
+ for weight in layer.weights:
+ if hasattr(weight, '_keras_initialized'):
+ return True
return False
@@ -520,7 +525,7 @@ def model_to_estimator(keras_model=None,
keras_model_fn, model_dir=model_dir, config=config)
# Check if we need to call get_weights:
- if _any_variable_initialized():
+ if _any_weight_initialized(keras_model):
keras_weights = keras_model.get_weights()
# Warn if config passed to estimator tries to update GPUOptions. If a
# session has already been created, the GPUOptions passed to the first
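Condensed, the new check asks one question: has Keras already initialized any weight of this particular model (which currently happens for all weights at `get_session()` time)? A restatement with an illustrative name:

```python
def _should_copy_keras_weights(keras_model):
  # True iff some weight of this model carries the Keras initialization
  # marker, in which case model_to_estimator calls keras_model.get_weights().
  return keras_model is not None and any(
      hasattr(w, '_keras_initialized')
      for layer in keras_model.layers for w in layer.weights)
```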
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 5e094ae92b..7a3c5a9bf1 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -32,7 +32,6 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
@@ -60,9 +59,9 @@ def simple_sequential_model():
return model
-def simple_functional_model():
+def simple_functional_model(activation='relu'):
a = keras.layers.Input(shape=_INPUT_SIZE)
- b = keras.layers.Dense(16, activation='relu')(a)
+ b = keras.layers.Dense(16, activation=activation)(a)
b = keras.layers.Dropout(0.1)(b)
b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)
model = keras.models.Model(inputs=[a], outputs=[b])
@@ -204,6 +203,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_tf_optimizer(self):
for model_type in ['sequential', 'functional']:
keras_model, (_, _), (
@@ -231,6 +231,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_subclassed_model(self):
keras_model, (_, _), (
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
@@ -472,21 +473,25 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
- keras_mobile = mobilenet.MobileNet(weights=None)
- keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+
+ keras_model = simple_functional_model(activation=relu6)
+ keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
custom_objects = {
- 'relu6': mobilenet.relu6,
- 'DepthwiseConv2D': mobilenet.DepthwiseConv2D
+ 'relu6': relu6
}
+
with self.assertRaisesRegexp(ValueError, 'relu6'):
with self.test_session():
keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
with self.test_session():
keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
custom_objects=custom_objects)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 3d60c63b68..aa594af2e4 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -485,7 +485,16 @@ class RunConfig(object):
self._init_distributed_setting_from_environment_var(tf_config)
- # Get session_config only for distributed mode (cluster_spec is present).
+ self._maybe_overwrite_session_config_for_distributed_training()
+
+ def _maybe_overwrite_session_config_for_distributed_training(self):
+ """Overwrites the session_config for distributed training.
+
+ The default overwrite is optimized for between-graph training. Subclasses
+ should override this method if necessary.
+ """
+ # Get session_config only for between-graph distributed mode (cluster_spec
+ # is present).
if not self._session_config and self._cluster_spec:
RunConfig._replace(
self,
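Factoring the overwrite into its own method lets a subclass opt out of the between-graph default; a hypothetical sketch of such an override:

```python
from tensorflow.python.estimator import run_config as run_config_lib

class InGraphRunConfig(run_config_lib.RunConfig):
  """Hypothetical RunConfig that keeps the user-supplied session_config."""

  def _maybe_overwrite_session_config_for_distributed_training(self):
    # For in-graph replication the between-graph rewrite does not apply,
    # so leave self._session_config exactly as the user configured it.
    pass
```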
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 5730101092..f5ac79ced2 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -312,10 +312,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
- def train_input_fn: # returns x, y
+ def train_input_fn(): # returns x, y
# please shuffle the data.
pass
- def eval_input_fn_eval: # returns x, y
+ def eval_input_fn(): # returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
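The corrected snippet still stubs out both input functions; for concreteness, a minimal `input_fn` sketch that satisfies the contract (illustrative feature names and shapes, assuming the `tf.data` API):

```python
import numpy as np
import tensorflow as tf

def train_input_fn():  # returns features, labels
  features = {'x': np.random.rand(1000, 2).astype(np.float32)}
  labels = np.random.randint(0, 2, size=(1000,)).astype(np.int32)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # Shuffle, as the comment above advises, then repeat and batch.
  return dataset.shuffle(1000).repeat().batch(32)
```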
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 295d4ca094..80707030e6 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -48,6 +48,39 @@ py_library(
],
)
+py_library(
+ name = "feature_column_v2",
+ srcs = ["feature_column_v2.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:template",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
filegroup(
name = "vocabulary_testdata",
srcs = [
@@ -92,3 +125,38 @@ py_test(
"//tensorflow/python/estimator:numpy_io",
],
)
+
+py_test(
+ name = "feature_column_v2_test",
+ srcs = ["feature_column_v2_test.py"],
+ data = [":vocabulary_testdata"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_pip",
+ ],
+ deps = [
+ ":feature_column_py",
+ ":feature_column_v2",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:numpy_io",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
new file mode 100644
index 0000000000..b4dd23f58d
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -0,0 +1,3600 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This API defines FeatureColumn abstraction.
+
+FeatureColumns provide a high level abstraction for ingesting and representing
+features. FeatureColumns are also the primary way of encoding features for
+canned @{tf.estimator.Estimator}s.
+
+When using FeatureColumns with `Estimators`, the type of feature column you
+should choose depends on (1) the feature type and (2) the model type.
+
+1. Feature type:
+
+ * Continuous features can be represented by `numeric_column`.
+ * Categorical features can be represented by any `categorical_column_with_*`
+ column:
+ - `categorical_column_with_vocabulary_list`
+ - `categorical_column_with_vocabulary_file`
+ - `categorical_column_with_hash_bucket`
+ - `categorical_column_with_identity`
+ - `weighted_categorical_column`
+
+2. Model type:
+
+ * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
+
+ Continuous features can be directly fed into deep neural network models.
+
+ age_column = numeric_column("age")
+
+ To feed sparse features into DNN models, wrap the column with
+ `embedding_column` or `indicator_column`. `indicator_column` is recommended
+ for features with only a few possible values. For features with many
+ possible values, to reduce the size of your model, `embedding_column` is
+ recommended.
+
+ embedded_dept_column = embedding_column(
+ categorical_column_with_vocabulary_list(
+ "department", ["math", "philosophy", ...]), dimension=10)
+
+ * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
+
+ Sparse features can be fed directly into linear models. They behave like an
+ indicator column but with an efficient implementation.
+
+ dept_column = categorical_column_with_vocabulary_list("department",
+ ["math", "philosophy", "english"])
+
+ It is recommended that continuous features be bucketized before being
+ fed into linear models.
+
+ bucketized_age_column = bucketized_column(
+ source_column=age_column,
+ boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
+
+ Sparse features can be crossed (also known as conjoined or combined) in
+ order to form non-linearities, and then fed into linear models.
+
+ cross_dept_age_column = crossed_column(
+ columns=["department", bucketized_age_column],
+ hash_bucket_size=1000)
+
+Example of building canned `Estimator`s using FeatureColumns:
+
+ ```python
+ # Define features and transformations
+ deep_feature_columns = [age_column, embedded_dept_column]
+ wide_feature_columns = [dept_column, bucketized_age_column,
+ cross_dept_age_column]
+
+ # Build deep model
+ estimator = DNNClassifier(
+ feature_columns=deep_feature_columns,
+ hidden_units=[500, 250, 50])
+ estimator.train(...)
+
+ # Or build a wide model
+ estimator = LinearClassifier(
+ feature_columns=wide_feature_columns)
+ estimator.train(...)
+
+ # Or build a wide and deep model!
+ estimator = DNNLinearCombinedClassifier(
+ linear_feature_columns=wide_feature_columns,
+ dnn_feature_columns=deep_feature_columns,
+ dnn_hidden_units=[500, 250, 50])
+ estimator.train(...)
+ ```
+
+
+FeatureColumns can also be transformed into a generic input layer for
+custom models using `input_layer`.
+
+Example of building a model using FeatureColumns; this can be used in a
+`model_fn` which is given to the @{tf.estimator.Estimator}:
+
+ ```python
+ # Building model via layers
+
+ deep_feature_columns = [age_column, embedded_dept_column]
+ columns_to_tensor = parse_feature_columns_from_examples(
+ serialized=my_data,
+ feature_columns=deep_feature_columns)
+ first_layer = input_layer(
+ features=columns_to_tensor,
+ feature_columns=deep_feature_columns)
+ second_layer = fully_connected(first_layer, ...)
+ ```
+
+NOTE: Functions prefixed with "_" indicate experimental or private parts of
+the API subject to change, and should not be relied upon!
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import math
+
+import numpy as np
+import six
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.feature_column import feature_column as fc_old
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine import training
+from tensorflow.python.layers import base
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.util import nest
+
+
+def _internal_input_layer(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None,
+ scope=None):
+ """See input_layer. `scope` is a name or variable scope to use."""
+
+ feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access
+ for column in feature_columns:
+ if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access
+ raise ValueError(
+ 'Items of feature_columns must be a _DenseColumn. '
+ 'You can wrap a categorical column with an '
+ 'embedding_column or indicator_column. Given: {}'.format(column))
+ weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
+ weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ # A non-None `scope` allows variable reuse when, e.g., this function
+ # is wrapped by a `make_template`.
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
+ output_tensors = []
+ ordered_columns = []
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name): # pylint: disable=protected-access
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
+ batch_size = array_ops.shape(tensor)[0]
+ output_tensors.append(
+ array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+ if cols_to_vars is not None:
+ # Retrieve any variables created (some _DenseColumn's don't create
+ # variables, in which case an empty list is returned).
+ cols_to_vars[column] = ops.get_collection(
+ ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=variable_scope.get_variable_scope().name)
+ _verify_static_batch_size_equality(output_tensors, ordered_columns)
+ return array_ops.concat(output_tensors, 1)
+
+
+def input_layer(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+ Generally a single example in training data is described with FeatureColumns.
+ At the first layer of the model, this column-oriented data should be converted
+ to a single `Tensor`.
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ keywords_embedded = embedding_column(
+ categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
+ columns = [price, keywords_embedded, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ for units in [128, 64, 32]:
+ dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
+ prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values can be a `SparseTensor` or a `Tensor` depending on the
+ corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as inputs
+ to your model. All items should be instances of classes derived from
+ `_DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical features,
+ you can wrap them with an `embedding_column` or `indicator_column`.
+ weight_collections: A list of collection names to which the Variable will be
+ added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with a
+ mapping from `_FeatureColumn` to list of `Variable`s. For example, after
+ the call, we might have cols_to_vars =
+ {_EmbeddingColumn(
+ categorical_column=_HashedCategoricalColumn(
+ key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
+ dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10)>,
+ <tf.Variable 'some_variable:1' shape=(5, 10)>]}
+ If a column creates no variables, its value will be an empty list.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
+ """
+ return _internal_input_layer(features, feature_columns, weight_collections,
+ trainable, cols_to_vars)
+
+
+# TODO(akshayka): InputLayer should be a subclass of Layer, and it
+# should implement the logic in input_layer using Layer's build-and-call
+# paradigm; input_layer should create an instance of InputLayer and
+# return the result of invoking its apply method, just as functional layers do.
+class InputLayer(object):
+ """An object-oriented version of `input_layer` that reuses variables."""
+
+ def __init__(self,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """See `input_layer`."""
+
+ self._feature_columns = feature_columns
+ self._weight_collections = weight_collections
+ self._trainable = trainable
+ self._cols_to_vars = cols_to_vars
+ self._input_layer_template = template.make_template(
+ 'feature_column_input_layer',
+ _internal_input_layer,
+ create_scope_now_=True)
+ self._scope = self._input_layer_template.variable_scope
+
+ def __call__(self, features):
+ return self._input_layer_template(
+ features=features,
+ feature_columns=self._feature_columns,
+ weight_collections=self._weight_collections,
+ trainable=self._trainable,
+ cols_to_vars=None,
+ scope=self._scope)
+
+ @property
+ def non_trainable_variables(self):
+ return self._input_layer_template.non_trainable_variables
+
+ @property
+ def non_trainable_weights(self):
+ return self._input_layer_template.non_trainable_weights
+
+ @property
+ def trainable_variables(self):
+ return self._input_layer_template.trainable_variables
+
+ @property
+ def trainable_weights(self):
+ return self._input_layer_template.trainable_weights
+
+ @property
+ def variables(self):
+ return self._input_layer_template.variables
+
+ @property
+ def weights(self):
+ return self._input_layer_template.weights
+
+
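+# Illustrative usage of `InputLayer` (names are placeholders): a single
+# instance can be applied to several feature dicts, reusing the variables
+# it created on the first call:
+#
+#   input_layer_fn = InputLayer(feature_columns=columns)
+#   train_dense = input_layer_fn(train_features)
+#   eval_dense = input_layer_fn(eval_features)  # same variables, reused
+
+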
+def linear_model(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """Returns a linear prediction `Tensor` based on given `feature_columns`.
+
+ This function generates a weighted sum based on output dimension `units`.
+ Weighted sum refers to logits in classification problems. It refers to the
+ prediction itself for linear regression problems.
+
+ Note on supported columns: `linear_model` treats categorical columns as
+ `indicator_column`s. To be specific, assume the input `SparseTensor` looks
+ like:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+ `linear_model` assigns weights for the presence of "a", "b", "c" implicitly,
+ just like `indicator_column`, while `input_layer` explicitly requires wrapping
+ each of the categorical columns with an `embedding_column` or an
+ `indicator_column`.
+
+ Example of usage:
+
+ ```python
+ price = numeric_column('price')
+ price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
+ keywords = categorical_column_with_hash_bucket("keywords", 10000)
+ keywords_price = crossed_column('keywords', price_buckets, ...)
+ columns = [price_buckets, keywords, keywords_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ prediction = linear_model(features, columns)
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values are `Tensor` or `SparseTensor` depending on the
+ corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as inputs
+ to your model. All items should be instances of classes derived from
+ `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except for `numeric_column`, almost all columns passed to
+ `linear_model` are treated as categorical columns. Each categorical column
+ is combined independently. Currently "mean", "sqrtn" and "sum" are
+ supported, with "sum" the default for linear models. "sqrtn" often achieves
+ good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+ with `sparse_combiner` as "mean", the linear model outputs conceptually are:
+ ```
+ y_0 = 1.0 / 2.0 * (w_a + w_b) + w_d + b
+ y_1 = w_c + 1.0 / 3.0 * (w_e + w_f + w_g) + b
+ ```
+ where `y_i` is the output for the i-th example, `b` is the bias, and `w_x`
+ is the weight assigned to the presence of `x` in the input features.
+ weight_collections: A list of collection names to which the Variable will be
+ added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with a
+ mapping from `_FeatureColumn` to associated list of `Variable`s. For
+ example, after the call, we might have cols_to_vars = {
+ _NumericColumn(
+ key='numeric_feature1', shape=(1,)):
+ [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
+ 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
+ _NumericColumn(
+ key='numeric_feature2', shape=(2,)):
+ [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
+ If a column creates no variables, its value will be an empty list. Note
+ that cols_to_vars will also contain a string key 'bias' that maps to a
+ list of Variables.
+
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its shape
+ is (batch_size, units) and its dtype is `float32`.
+
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
+ nor `_CategoricalColumn`.
+ """
+ with variable_scope.variable_scope(None, 'linear_model') as vs:
+ model_name = _strip_leading_slashes(vs.name)
+ linear_model_layer = _LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ name=model_name)
+ retval = linear_model_layer(features) # pylint: disable=not-callable
+ if cols_to_vars is not None:
+ cols_to_vars.update(linear_model_layer.cols_to_vars())
+ return retval
+
+
+def _add_to_collections(var, weight_collections):
+ """Adds a var to the list of weight_collections provided.
+
+ Handles the case for partitioned and non-partitioned variables.
+
+ Args:
+ var: A variable or Partitioned Variable.
+ weight_collections: List of collections to add variable to.
+ """
+ for weight_collection in weight_collections:
+ # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
+ if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
+ continue
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collection(weight_collection, constituent_var)
+ else:
+ ops.add_to_collection(weight_collection, var)
+
+
+class _FCLinearWrapper(base.Layer):
+ """Wraps a _FeatureColumn in a layer for use in a linear model.
+
+ See `linear_model` above.
+ """
+
+ def __init__(self,
+ feature_column,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_FCLinearWrapper, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._feature_column = feature_column
+ self._units = units
+ self._sparse_combiner = sparse_combiner
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ else:
+ num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=[num_elements, self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ _add_to_collections(weight, self._weight_collections)
+ self._weight_var = weight
+ self.built = True
+
+ def call(self, builder):
+ weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
+ column=self._feature_column,
+ builder=builder,
+ units=self._units,
+ sparse_combiner=self._sparse_combiner,
+ weight_collections=self._weight_collections,
+ trainable=self.trainable,
+ weight_var=self._weight_var)
+ return weighted_sum
+
+
+class _BiasLayer(base.Layer):
+ """A layer for the bias term.
+ """
+
+ def __init__(self,
+ units=1,
+ trainable=True,
+ weight_collections=None,
+ name=None,
+ **kwargs):
+ super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
+ self._units = units
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._bias_variable = self.add_variable(
+ 'bias_weights',
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ _add_to_collections(self._bias_variable, self._weight_collections)
+ self.built = True
+
+ def call(self, _):
+ return self._bias_variable
+
+
+def _get_expanded_variable_list(variable):
+ if (isinstance(variable, variables.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ return [variable] # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ return list(variable)
+
+
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
+
+class _LinearModel(training.Model):
+ """Creates a linear model using feature columns.
+
+ See `linear_model` for details.
+ """
+
+ def __init__(self,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_LinearModel, self).__init__(name=name, **kwargs)
+ self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
+ feature_columns)
+ self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ column_layers = {}
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
+ # Using the fully qualified variable scope name would repeat the outer
+ # scope (the scope this method was called with) in the name of the
+ # variable that gets created.
+ column_name = _strip_leading_slashes(vs.name)
+ column_layer = _FCLinearWrapper(column, units, sparse_combiner,
+ self._weight_collections, trainable,
+ column_name, **kwargs)
+ column_layers[column_name] = column_layer
+ self._column_layers = self._add_layers(column_layers)
+ self._bias_layer = _BiasLayer(
+ units=units,
+ trainable=trainable,
+ weight_collections=self._weight_collections,
+ name='bias_layer',
+ **kwargs)
+ self._cols_to_vars = {}
+
+ def cols_to_vars(self):
+ """Returns a dict mapping _FeatureColumns to variables.
+
+ See `linear_model` for more information.
+ This is not populated until `call` is invoked, i.e. until the layer is built.
+ """
+ return self._cols_to_vars
+
+ def call(self, features):
+ with variable_scope.variable_scope(self.name):
+ for column in self._feature_columns:
+ if not isinstance(
+ column,
+ (
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ weighted_sums = []
+ ordered_columns = []
+ builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
+ for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
+ column = layer._feature_column # pylint: disable=protected-access
+ ordered_columns.append(column)
+ weighted_sum = layer(builder)
+ weighted_sums.append(weighted_sum)
+ self._cols_to_vars[column] = ops.get_collection(
+ ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
+
+ _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias,
+ self._bias_layer( # pylint: disable=not-callable
+ builder,
+ scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
+ name='weighted_sum')
+ bias = self._bias_layer.variables[0]
+ self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
+ return predictions
+
+ def _add_layers(self, layers):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ for name, layer in layers.items():
+ setattr(self, 'layer-%s' % name, layer)
+ return layers
+
+
+def _transform_features(features, feature_columns, state_manager):
+ """Returns transformed features based on features columns passed in.
+
+ Please note that most probably you would not need to use this function. Please
+ check `input_layer` and `linear_model` to see whether they will
+ satisfy your use case or not.
+
+ Example:
+
+ ```python
+ # Define features and transformations
+ crosses_a_x_b = crossed_column(
+ columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
+ price_buckets = bucketized_column(
+ source_column=numeric_column("price"), boundaries=[...])
+
+ columns = [crosses_a_x_b, price_buckets]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ transformed = transform_features(features=features, feature_columns=columns)
+
+ assertCountEqual(columns, transformed.keys())
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values can be a `SparseTensor` or a `Tensor` depending on the
+ corresponding `FeatureColumn`.
+ feature_columns: An iterable containing all the `FeatureColumn`s.
+ state_manager: A StateManager object that holds the FeatureColumn state.
+
+ Returns:
+ A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
+ """
+ feature_columns = _normalize_feature_columns(feature_columns)
+ outputs = {}
+ with ops.name_scope(
+ None, default_name='transform_features', values=features.values()):
+ transformation_cache = FeatureTransformationCache(features)
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ with ops.name_scope(None, default_name=column.name):
+ outputs[column] = transformation_cache.get(column, state_manager)
+ return outputs
+
+
+def make_parse_example_spec(feature_columns):
+ """Creates parsing spec dictionary from input feature_columns.
+
+ The returned dictionary can be used as arg 'features' in `tf.parse_example`.
+
+ Typical usage example:
+
+ ```python
+ # Define features and transformations
+ feature_a = categorical_column_with_vocabulary_file(...)
+ feature_b = numeric_column(...)
+ feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...)
+ feature_a_x_feature_c = crossed_column(
+ columns=["feature_a", feature_c_bucketized], ...)
+
+ feature_columns = set(
+ [feature_b, feature_c_bucketized, feature_a_x_feature_c])
+ features = tf.parse_example(
+ serialized=serialized_examples,
+ features=make_parse_example_spec(feature_columns))
+ ```
+
+ For the above example, make_parse_example_spec would return the dict:
+
+ ```python
+ {
+ "feature_a": parsing_ops.VarLenFeature(tf.string),
+ "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
+ "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
+ }
+ ```
+
+ Args:
+ feature_columns: An iterable containing all feature columns. All items
+ should be instances of classes derived from `FeatureColumn`.
+
+ Returns:
+ A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
+ value.
+
+ Raises:
+ ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
+ instance.
+ """
+ result = {}
+ for column in feature_columns:
+ if not isinstance(column, FeatureColumn):
+ raise ValueError('All feature_columns must be FeatureColumn instances. '
+ 'Given: {}'.format(column))
+ config = column.parse_example_spec
+ for key, value in six.iteritems(config):
+ if key in result and value != result[key]:
+ raise ValueError(
+ 'feature_columns contain different parse_spec for key '
+ '{}. Given {} and {}'.format(key, value, result[key]))
+ result.update(config)
+ return result
+
+
+def embedding_column(
+ categorical_column, dimension, combiner='mean', initializer=None,
+ ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
+ trainable=True):
+ """`_DenseColumn` that converts from sparse, categorical input.
+
+ Use this when your inputs are sparse, but you want to convert them to a dense
+ representation (e.g., to feed to a DNN).
+
+ Inputs must be a `CategoricalColumn` created by any of the
+ `categorical_column_*` functions. Here is an example of using
+ `embedding_column` with `DNNClassifier`:
+
+ ```python
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [embedding_column(video_id, 9),...]
+
+ estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
+
+ label_column = ...
+ def input_fn():
+ features = tf.parse_example(
+ ..., features=make_parse_example_spec(columns + [label_column]))
+ labels = features.pop(label_column.name)
+ return features, labels
+
+ estimator.train(input_fn=input_fn, steps=100)
+ ```
+
+ Here is an example using `embedding_column` with model_fn:
+
+ ```python
+ def model_fn(features, ...):
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [embedding_column(video_id, 9),...]
+ dense_tensor = input_layer(features, columns)
+ # Form DNN layers, calculate loss, and return EstimatorSpec.
+ ...
+ ```
+
+ Args:
+ categorical_column: A `CategoricalColumn` created by a
+ `categorical_column_with_*` function. This column produces the sparse IDs
+ that are inputs to the embedding lookup.
+ dimension: An integer specifying dimension of the embedding, must be > 0.
+ combiner: A string specifying how to reduce if there are multiple entries
+ in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ 'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+ with bag-of-words columns. Each of these can be thought of as an
+ example-level normalization of the column. For more information, see
+ `tf.embedding_lookup_sparse`.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ ckpt_to_load_from: String representing checkpoint name/pattern from which to
+ restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+ which to restore the column weights. Required if `ckpt_to_load_from` is
+ not `None`.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
+ trainable: Whether or not the embedding is trainable. Default is True.
+
+ Returns:
+ `DenseColumn` that converts from sparse input.
+
+ Raises:
+ ValueError: if `dimension` not > 0.
+ ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+ is specified.
+ ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: If eager execution is enabled.
+ """
+ if (dimension is None) or (dimension < 1):
+ raise ValueError('Invalid dimension {}.'.format(dimension))
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError('Must specify both `ckpt_to_load_from` and '
+ '`tensor_name_in_ckpt` or none of them.')
+
+ if (initializer is not None) and (not callable(initializer)):
+ raise ValueError('initializer must be callable if specified. '
+ 'Embedding of column_name: {}'.format(
+ categorical_column.name))
+ if initializer is None:
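+ # Default stddev is 1/sqrt(dimension); e.g. dimension=9 gives stddev 1/3.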
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1 / math.sqrt(dimension))
+
+ return EmbeddingColumn(
+ categorical_column=categorical_column,
+ dimension=dimension,
+ combiner=combiner,
+ initializer=initializer,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable)
+
+
+def shared_embedding_columns(
+ categorical_columns, dimension, combiner='mean', initializer=None,
+ shared_embedding_collection_name=None, ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+ """List of dense columns that convert from sparse, categorical input.
+
+ This is similar to `embedding_column`, except that it produces a list of
+ embedding columns that share the same embedding weights.
+
+ Use this when your inputs are sparse and of the same type (e.g. watched and
+ impression video IDs that share the same vocabulary), and you want to convert
+ them to a dense representation (e.g., to feed to a DNN).
+
+ Inputs must be a list of categorical columns created by any of the
+ `categorical_column_*` functions. They must all be of the same type and have
+ the same arguments except `key`. E.g. they can be
+ categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
+ all columns could also be weighted_categorical_column.
+
+ Here is an example embedding of two features for a DNNClassifier model:
+
+ ```python
+ watched_video_id = categorical_column_with_vocabulary_file(
+ 'watched_video_id', video_vocabulary_file, video_vocabulary_size)
+ impression_video_id = categorical_column_with_vocabulary_file(
+ 'impression_video_id', video_vocabulary_file, video_vocabulary_size)
+ columns = shared_embedding_columns(
+ [watched_video_id, impression_video_id], dimension=10)
+
+ estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
+
+ label_column = ...
+ def input_fn():
+ features = tf.parse_example(
+ ..., features=make_parse_example_spec(columns + [label_column]))
+ labels = features.pop(label_column.name)
+ return features, labels
+
+ estimator.train(input_fn=input_fn, steps=100)
+ ```
+
+ Here is an example using `shared_embedding_columns` with model_fn:
+
+ ```python
+ def model_fn(features, ...):
+ watched_video_id = categorical_column_with_vocabulary_file(
+ 'watched_video_id', video_vocabulary_file, video_vocabulary_size)
+ impression_video_id = categorical_column_with_vocabulary_file(
+ 'impression_video_id', video_vocabulary_file, video_vocabulary_size)
+ columns = shared_embedding_columns(
+ [watched_video_id, impression_video_id], dimension=10)
+ dense_tensor = input_layer(features, columns)
+ # Form DNN layers, calculate loss, and return EstimatorSpec.
+ ...
+ ```
+
+ Args:
+ categorical_columns: List of categorical columns created by a
+ `categorical_column_with_*` function. These columns produce the sparse IDs
+ that are inputs to the embedding lookup. All columns must be of the same
+ type and have the same arguments except `key`. E.g. they can be
+ categorical_column_with_vocabulary_file with the same vocabulary_file.
+ Some or all columns could also be weighted_categorical_column.
+ dimension: An integer specifying dimension of the embedding, must be > 0.
+ combiner: A string specifying how to reduce if there are multiple entries
+ in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ 'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+ with bag-of-words columns. Each of these can be thought of as an
+ example-level normalization of the column. For more information, see
+ `tf.embedding_lookup_sparse`.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ shared_embedding_collection_name: Optional collective name of these columns.
+ If not given, a reasonable name will be chosen based on the names of
+ `categorical_columns`.
+ ckpt_to_load_from: String representing checkpoint name/pattern from which to
+ restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+ which to restore the column weights. Required if `ckpt_to_load_from` is
+ not `None`.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
+ trainable: Whether or not the embedding is trainable. Default is True.
+
+ Returns:
+ A list of dense columns that converts from sparse input. The order of
+ results follows the ordering of `categorical_columns`.
+
+ Raises:
+ ValueError: if `dimension` not > 0.
+ ValueError: if any of the given `categorical_columns` is of different type
+ or has different arguments than the others.
+ ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+ is specified.
+ ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: if eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError('shared_embedding_columns are not supported when eager '
+ 'execution is enabled.')
+
+ if (dimension is None) or (dimension < 1):
+ raise ValueError('Invalid dimension {}.'.format(dimension))
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError('Must specify both `ckpt_to_load_from` and '
+ '`tensor_name_in_ckpt` or none of them.')
+
+ if (initializer is not None) and (not callable(initializer)):
+ raise ValueError('initializer must be callable if specified.')
+ if initializer is None:
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1. / math.sqrt(dimension))
+
+ # Sort the columns so the default collection name is deterministic even if the
+ # user passes columns from an unsorted collection, such as dict.values().
+ sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
+
+ c0 = sorted_columns[0]
+ if not isinstance(c0, CategoricalColumn):
+ raise ValueError(
+ 'All categorical_columns must be subclasses of CategoricalColumn. '
+ 'Given: {}, of type: {}'.format(c0, type(c0)))
+ num_buckets = c0.num_buckets
+ if isinstance(c0, WeightedCategoricalColumn):
+ c0 = c0.categorical_column
+ for c in sorted_columns[1:]:
+ if isinstance(c, WeightedCategoricalColumn):
+ c = c.categorical_column
+ if not isinstance(c, type(c0)):
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same type, or be weighted_categorical_column of the same type. '
+ 'Given column: {} of type: {} does not match given column: {} of '
+ 'type: {}'.format(c0, type(c0), c, type(c)))
+ if num_buckets != c.num_buckets:
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same number of buckets. Given column: {} with buckets: {} does '
+ 'not match column: {} with buckets: {}'.format(
+ c0, num_buckets, c, c.num_buckets))
+
+ if not shared_embedding_collection_name:
+ shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
+ shared_embedding_collection_name += '_shared_embedding'
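+ # E.g. columns named 'impression_video_id' and 'watched_video_id' yield
+ # 'impression_video_id_watched_video_id_shared_embedding'.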
+
+ result = []
+ for column in categorical_columns:
+ result.append(
+ SharedEmbeddingColumn(
+ categorical_column=column,
+ initializer=initializer,
+ dimension=dimension,
+ combiner=combiner,
+ shared_embedding_collection_name=shared_embedding_collection_name,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable))
+
+ return result
+
+
+def numeric_column(key,
+ shape=(1,),
+ default_value=None,
+ dtype=dtypes.float32,
+ normalizer_fn=None):
+ """Represents real valued or numerical features.
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ columns = [price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+
+ # or
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
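+
+ A `normalizer_fn` can also be attached; as an illustrative sketch (the
+ scaling constants are arbitrary):
+
+ ```python
+ price = numeric_column('price', normalizer_fn=lambda x: (x - 3.0) / 4.2)
+ ```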
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ shape: An iterable of integers specifies the shape of the `Tensor`. An
+ integer can be given which means a single dimension `Tensor` with given
+ width. The `Tensor` representing the column will have the shape of
+ [batch_size] + `shape`.
+ default_value: A single value compatible with `dtype` or an iterable of
+ values compatible with `dtype` which the column takes on during
+ `tf.Example` parsing if data is missing. A default value of `None` will
+ cause `tf.parse_example` to fail if an example does not contain this
+ column. If a single value is provided, the same value will be applied as
+ the default value for every item. If an iterable of values is provided,
+ the shape of the `default_value` should be equal to the given `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ normalizer_fn: If not `None`, a function that can be used to normalize the
+ value of the tensor after `default_value` is applied for parsing.
+ Normalizer function takes the input `Tensor` as its argument, and returns
+ the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2). Please note that
+ even though the most common use case of this function is normalization, it
+ can be used for any kind of TensorFlow transformation.
+
+ Returns:
+ A `NumericColumn`.
+
+ Raises:
+ TypeError: if any dimension in shape is not an int
+ ValueError: if any dimension in shape is not a positive integer
+ TypeError: if `default_value` is an iterable but not compatible with `shape`
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ shape = _check_shape(shape, key)
+ if not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype must be convertible to float. '
+ 'dtype: {}, key: {}'.format(dtype, key))
+ default_value = _check_default_value(shape, default_value, dtype, key)
+
+ if normalizer_fn is not None and not callable(normalizer_fn):
+ raise TypeError(
+ 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+
+ _assert_key_is_string(key)
+ return NumericColumn(
+ key,
+ shape=shape,
+ default_value=default_value,
+ dtype=dtype,
+ normalizer_fn=normalizer_fn)
+
+
+def bucketized_column(source_column, boundaries):
+ """Represents discretized dense input.
+
+ Buckets include the left boundary, and exclude the right boundary. Namely,
+ `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
+ `[1., 2.)`, and `[2., +inf)`.
+
+ For example, if the inputs are
+
+ ```python
+ boundaries = [0, 10, 100]
+ input_tensor = [[-5, 10000],
+ [150, 10],
+ [5, 100]]
+ ```
+
+ then the output will be
+
+ ```python
+ output = [[0, 3],
+ [3, 2],
+ [1, 3]]
+ ```
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+
+ # or
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ `bucketized_column` can also be crossed with another categorical column using
+ `crossed_column`:
+
+ ```python
+ price = numeric_column('price')
+ # bucketized_column converts numerical feature to a categorical one.
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ # 'keywords' is a string feature.
+ price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50000)
+ columns = [price_x_keywords, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Args:
+ source_column: A one-dimensional dense column which is generated with
+ `numeric_column`.
+ boundaries: A sorted list or tuple of floats specifying the boundaries.
+
+ Returns:
+ A `BucketizedColumn`.
+
+ Raises:
+ ValueError: If `source_column` is not a numeric column, or if it is not
+ one-dimensional.
+ ValueError: If `boundaries` is not a sorted list or tuple.
+ """
+ if not isinstance(source_column, NumericColumn):
+ raise ValueError(
+ 'source_column must be a column generated with numeric_column(). '
+ 'Given: {}'.format(source_column))
+ if len(source_column.shape) > 1:
+ raise ValueError(
+ 'source_column must be one-dimensional column. '
+ 'Given: {}'.format(source_column))
+ if (not boundaries or
+ not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
+ raise ValueError('boundaries must be a sorted list.')
+ for i in range(len(boundaries) - 1):
+ if boundaries[i] >= boundaries[i + 1]:
+ raise ValueError('boundaries must be a sorted list.')
+ return BucketizedColumn(source_column, tuple(boundaries))
+
+
+def _assert_string_or_int(dtype, prefix):
+ if (dtype != dtypes.string) and (not dtype.is_integer):
+ raise ValueError(
+ '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+
+
+def _assert_key_is_string(key):
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ 'key must be a string. Got: type {}. Given key: {}.'.format(
+ type(key), key))
+
+
+def categorical_column_with_hash_bucket(key,
+ hash_bucket_size,
+ dtype=dtypes.string):
+ """Represents sparse feature where ids are set by hashing.
+
+ Use this when your sparse features are in string or integer format, and you
+ want to distribute your inputs into a finite number of buckets by hashing.
+ output_id = Hash(input_feature_string) % bucket_size for string type input.
+ For int type input, the value is converted to its string representation first
+ and then hashed by the same formula.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example:
+
+ ```python
+ keywords = categorical_column_with_hash_bucket("keywords", 10000)
+ columns = [keywords, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+
+ # or
+ keywords_embedded = embedding_column(keywords, 16)
+ columns = [keywords_embedded, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ hash_bucket_size: An int >= 1. The number of buckets.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `HashedCategoricalColumn`.
+
+ Raises:
+ ValueError: `hash_bucket_size` is less than 1.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if hash_bucket_size is None:
+ raise ValueError('hash_bucket_size must be set. key: {}'.format(key))
+
+ if hash_bucket_size < 1:
+ raise ValueError('hash_bucket_size must be at least 1. '
+ 'hash_bucket_size: {}, key: {}'.format(
+ hash_bucket_size, key))
+
+ _assert_key_is_string(key)
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+
+ return HashedCategoricalColumn(key, hash_bucket_size, dtype)
+
+
+def categorical_column_with_vocabulary_file(key,
+ vocabulary_file,
+ vocabulary_size=None,
+ num_oov_buckets=0,
+ default_value=None,
+ dtype=dtypes.string):
+ """A `CategoricalColumn` with a vocabulary file.
+
+ Use this when your inputs are in string or integer format, and you have a
+ vocabulary file that maps each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example with `num_oov_buckets`:
+ File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
+ abbreviation. All inputs with values in that file are assigned an ID 0-49,
+ corresponding to its line number. All other values are hashed and assigned an
+ ID 50-54.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
+ num_oov_buckets=5)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Example with `default_value`:
+ File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
+ other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
+ in input, and other values missing from the file, will be assigned ID 0. All
+ others are assigned the corresponding line number 1-50.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
+ default_value=0)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ And to make an embedding with either:
+
+ ```python
+ columns = [embedding_column(states, 3),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_file: The vocabulary file name.
+ vocabulary_size: Number of the elements in the vocabulary. This must be no
+ greater than the length of `vocabulary_file`; if it is less, later values
+ are ignored. If `None`, it is set to the length of `vocabulary_file`.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+ the input value. A positive `num_oov_buckets` can not be specified with
+ `default_value`.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to `-1`. This can not be specified with a positive
+ `num_oov_buckets`.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `CategoricalColumn` with a vocabulary file.
+
+ Raises:
+ ValueError: `vocabulary_file` is missing or cannot be opened.
+ ValueError: `vocabulary_size` is missing or < 1.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if not vocabulary_file:
+ raise ValueError('Missing vocabulary_file in {}.'.format(key))
+
+ if vocabulary_size is None:
+ if not gfile.Exists(vocabulary_file):
+ raise ValueError('vocabulary_file in {} does not exist.'.format(key))
+
+ with gfile.GFile(vocabulary_file) as f:
+ vocabulary_size = sum(1 for _ in f)
+ logging.info(
+ 'vocabulary_size = %d in %s is inferred from the number of elements '
+ 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
+
+ # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
+ if vocabulary_size < 1:
+ raise ValueError('Invalid vocabulary_size in {}.'.format(key))
+ if num_oov_buckets:
+ if default_value is not None:
+ raise ValueError(
+ 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+ key))
+ if num_oov_buckets < 0:
+ raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+ num_oov_buckets, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
+ return VocabularyFileCategoricalColumn(
+ key=key,
+ vocabulary_file=vocabulary_file,
+ vocabulary_size=vocabulary_size,
+ num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
+ default_value=-1 if default_value is None else default_value,
+ dtype=dtype)
+
+
+def categorical_column_with_vocabulary_list(
+ key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
+ """A `_CategoricalColumn` with in-memory vocabulary.
+
+ Use this when your inputs are in string or integer format, and you have an
+ in-memory vocabulary mapping each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example with `num_oov_buckets`:
+ In the following example, each input in `vocabulary_list` is assigned an ID
+ 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
+ inputs are hashed and assigned an ID 4-5.
+
+ ```python
+ colors = categorical_column_with_vocabulary_list(
+ key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
+ num_oov_buckets=2)
+ columns = [colors, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ Example with `default_value`:
+ In the following example, each input in `vocabulary_list` is assigned an ID
+ 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
+ inputs are assigned `default_value` 0.
+
+ ```python
+ colors = categorical_column_with_vocabulary_list(
+ key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
+ columns = [colors, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ And to make an embedding with either:
+
+ ```python
+ columns = [embedding_column(colors, 3),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_list: An ordered iterable defining the vocabulary. Each feature
+ is mapped to the index of its value (if present) in `vocabulary_list`.
+ Must be castable to `dtype`.
+ dtype: The type of features. Only string and integer types are supported.
+ If `None`, it will be inferred from `vocabulary_list`.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to `-1`. This can not be specified with a positive
+ `num_oov_buckets`.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
+ hash of the input value. A positive `num_oov_buckets` can not be specified
+ with `default_value`.
+
+ Returns:
+ A `CategoricalColumn` with in-memory vocabulary.
+
+ Raises:
+ ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: if `dtype` is not integer or string.
+ """
+ if (vocabulary_list is None) or (len(vocabulary_list) < 1):
+ raise ValueError(
+ 'vocabulary_list {} must be non-empty, column_name: {}'.format(
+ vocabulary_list, key))
+ if len(set(vocabulary_list)) != len(vocabulary_list):
+ raise ValueError(
+ 'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
+ vocabulary_list, key))
+ vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
+ if num_oov_buckets:
+ if default_value != -1:
+ raise ValueError(
+ 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+ key))
+ if num_oov_buckets < 0:
+ raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+ num_oov_buckets, key))
+ _assert_string_or_int(
+ vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
+ if dtype is None:
+ dtype = vocabulary_dtype
+ elif dtype.is_integer != vocabulary_dtype.is_integer:
+ raise ValueError(
+ 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
+ dtype, vocabulary_dtype, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
+
+ return VocabularyListCategoricalColumn(
+ key=key,
+ vocabulary_list=tuple(vocabulary_list),
+ dtype=dtype,
+ default_value=default_value,
+ num_oov_buckets=num_oov_buckets)
+
+
+def categorical_column_with_identity(key, num_buckets, default_value=None):
+ """A `CategoricalColumn` that returns identity values.
+
+ Use this when your inputs are integers in the range `[0, num_buckets)`, and
+ you want to use the input value itself as the categorical ID. Values outside
+ this range will result in `default_value` if specified, otherwise it will
+ fail.
+
+ Typically, this is used for contiguous ranges of integer indexes, but
+ it doesn't have to be. This might be inefficient, however, if many of the
+ IDs are unused. Consider `categorical_column_with_hash_bucket` in that case.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ In the following examples, each input in the range `[0, 1000000)` is assigned
+ the same value. All other inputs are assigned `default_value` 0. Note that a
+ literal 0 in inputs will result in the same default ID.
+
+ Linear model:
+
+ ```python
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [video_id, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ Embedding for a DNN model:
+
+ ```python
+ columns = [embedding_column(video_id, 9),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
+ default_value: If `None`, this column's graph operations will fail for
+ out-of-range inputs. Otherwise, this value must be in the range
+ `[0, num_buckets)`, and will replace out-of-range inputs.
+
+ Returns:
+ A `CategoricalColumn` that returns identity values.
+
+ Raises:
+ ValueError: if `num_buckets` is less than one.
+ ValueError: if `default_value` is not in range `[0, num_buckets)`.
+ """
+ if num_buckets < 1:
+ raise ValueError(
+ 'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
+ if (default_value is not None) and (
+ (default_value < 0) or (default_value >= num_buckets)):
+ raise ValueError(
+ 'default_value {} not in range [0, {}), column_name {}'.format(
+ default_value, num_buckets, key))
+ _assert_key_is_string(key)
+ return IdentityCategoricalColumn(
+ key=key, number_buckets=num_buckets, default_value=default_value)
+
+
+def indicator_column(categorical_column):
+ """Represents multi-hot representation of given categorical column.
+
+ - For DNN models, `indicator_column` can be used to wrap any
+ `categorical_column_*` (e.g., to feed to a DNN). Consider using
+ `embedding_column` if the number of buckets/unique values is large.
+
+ - For Wide (aka linear) model, `indicator_column` is the internal
+ representation for categorical column when passing categorical column
+ directly (as any element in feature_columns) to `linear_model`. See
+ `linear_model` for details.
+
+ ```python
+ name = indicator_column(categorical_column_with_vocabulary_list(
+ 'name', ['bob', 'george', 'wanda']))
+ columns = [name, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+
+ dense_tensor == [[1, 0, 0]] # If "name" bytes_list is ["bob"]
+ dense_tensor == [[1, 0, 1]] # If "name" bytes_list is ["bob", "wanda"]
+ dense_tensor == [[2, 0, 0]] # If "name" bytes_list is ["bob", "bob"]
+ ```
+
+ Args:
+ categorical_column: A `CategoricalColumn` which is created by
+ `categorical_column_with_*` or `crossed_column` functions.
+
+ Returns:
+ An `IndicatorColumn`.
+ """
+ return IndicatorColumn(categorical_column)
+
+
+def weighted_categorical_column(
+ categorical_column, weight_feature_key, dtype=dtypes.float32):
+ """Applies weight values to a `_CategoricalColumn`.
+
+ Use this when each of your sparse inputs has both an ID and a value. For
+ example, if you're representing text documents as a collection of word
+ frequencies, you can provide 2 parallel sparse input features ('terms' and
+ 'frequencies' below).
+
+ Example:
+
+ Input `tf.Example` objects:
+
+ ```proto
+ [
+ features {
+ feature {
+ key: "terms"
+ value {bytes_list {value: "very" value: "model"}}
+ }
+ feature {
+ key: "frequencies"
+ value {float_list {value: 0.3 value: 0.1}}
+ }
+ },
+ features {
+ feature {
+ key: "terms"
+ value {bytes_list {value: "when" value: "course" value: "human"}}
+ }
+ feature {
+ key: "frequencies"
+ value {float_list {value: 0.4 value: 0.1 value: 0.2}}
+ }
+ }
+ ]
+ ```
+
+ ```python
+ categorical_column = categorical_column_with_hash_bucket(
+ column_name='terms', hash_bucket_size=1000)
+ weighted_column = weighted_categorical_column(
+ categorical_column=categorical_column, weight_feature_key='frequencies')
+ columns = [weighted_column, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ This assumes the input dictionary contains a `SparseTensor` for key
+ 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
+ the same indices and dense shape.
+
+ Args:
+ categorical_column: A `CategoricalColumn` created by
+ `categorical_column_with_*` functions.
+ weight_feature_key: String key for weight values.
+ dtype: Type of weights, such as `tf.float32`. Only float and integer weights
+ are supported.
+
+ Returns:
+ A `CategoricalColumn` composed of two sparse features: one represents id,
+ the other represents weight (value) of the id feature in that example.
+
+ Raises:
+ ValueError: if `dtype` is not convertible to float.
+ """
+ if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype {} is not convertible to float.'.format(dtype))
+ return WeightedCategoricalColumn(
+ categorical_column=categorical_column,
+ weight_feature_key=weight_feature_key,
+ dtype=dtype)
+
+
+def crossed_column(keys, hash_bucket_size, hash_key=None):
+ """Returns a column for performing crosses of categorical features.
+
+ Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
+ the transformation can be thought of as:
+ Hash(cartesian product of features) % `hash_bucket_size`
+
+ For example, if the input features are:
+
+ * SparseTensor referred by first key:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+
+ * SparseTensor referred by second key:
+
+ ```python
+ shape = [2, 1]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ }
+ ```
+
+ then crossed feature will look like:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
+ [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
+ [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
+ }
+ ```
+
+ Here is an example to create a linear model with crosses of string features:
+
+ ```python
+ keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
+ columns = [keywords_x_doc_terms, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ You could also use vocabulary lookup before crossing:
+
+ ```python
+ keywords = categorical_column_with_vocabulary_file(
+ 'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
+ keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
+ columns = [keywords_x_doc_terms, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ If an input feature is of numeric type, you can use
+ `categorical_column_with_identity`, or `bucketized_column`, as in the example:
+
+ ```python
+ # vertical_id is an integer categorical feature.
+ vertical_id = categorical_column_with_identity('vertical_id', 10000)
+ price = numeric_column('price')
+ # bucketized_column converts numerical feature to a categorical one.
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
+ columns = [vertical_id_x_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ To use crossed column in DNN model, you need to add it in an embedding column
+ as in this example:
+
+ ```python
+ vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
+ vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
+ dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
+ ```
+
+ Args:
+ keys: An iterable identifying the features to be crossed. Each element can
+ be either:
+ * string: Will use the corresponding feature which must be of string type.
+ * `CategoricalColumn`: Will use the transformed tensor produced by this
+ column. Does not support hashed categorical column.
+ hash_bucket_size: An int >= 1. The number of buckets.
+ hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
+ function to combine the crosses fingerprints on SparseCrossOp (optional).
+
+ Returns:
+ A `CrossedColumn`.
+
+ Raises:
+ ValueError: If `len(keys) < 2`.
+ ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
+ ValueError: If any of the keys is `HashedCategoricalColumn`.
+ ValueError: If `hash_bucket_size < 1`.
+ """
+ if not hash_bucket_size or hash_bucket_size < 1:
+ raise ValueError('hash_bucket_size must be at least 1. '
+ 'hash_bucket_size: {}'.format(hash_bucket_size))
+ if not keys or len(keys) < 2:
+ raise ValueError(
+ 'keys must be a list with length > 1. Given: {}'.format(keys))
+ for key in keys:
+ if (not isinstance(key, six.string_types) and
+ not isinstance(key, CategoricalColumn)):
+ raise ValueError(
+ 'Unsupported key type. All keys must be either string, or '
+ 'categorical column except HashedCategoricalColumn. '
+ 'Given: {}'.format(key))
+ if isinstance(key, HashedCategoricalColumn):
+ raise ValueError(
+ 'categorical_column_with_hash_bucket is not supported for crossing. '
+ 'Hashing before crossing will increase probability of collision. '
+ 'Instead, use the feature name as a string. Given: {}'.format(key))
+ return CrossedColumn(
+ keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
+
+
+class StateManager(object):
+ """Manages the state associated with FeatureColumns.
+
+ Some `FeatureColumn`s create variables or resources to assist their
+ computation. The `StateManager` is responsible for creating and storing these
+ objects since `FeatureColumn`s are supposed to be stateless configuration
+ only.
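+
+ As an illustrative sketch only (this hypothetical subclass is not part of
+ the API), an implementation might cache one variable per (column, name):
+
+ ```python
+ class _DictStateManager(StateManager): # hypothetical example
+ def __init__(self):
+ self._variables = {} # maps (column name, variable name) to variable
+
+ def get_variable(self, feature_column, name, shape, dtype=None,
+ initializer=None):
+ key = (feature_column.name, name)
+ if key not in self._variables:
+ # Create the variable once; later calls return the cached one.
+ self._variables[key] = variable_scope.get_variable(
+ name, shape=shape, dtype=dtype, initializer=initializer)
+ return self._variables[key]
+ ```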
+ """
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ """Creates a new variable or returns an existing one.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ shape: variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ initializer: initializer instance (callable).
+
+ Returns:
+ The variable.
+ """
+ raise NotImplementedError('StateManager.get_variable')
+
+ def get_resource(self, feature_column, name, resource_creator):
+ """Creates a new resource or returns an existing one.
+
+ Resources can be things such as tables etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: Name of the resource.
+ resource_creator: A callable that can create the resource.
+
+ Returns:
+ The resource.
+ """
+ raise NotImplementedError('StateManager.get_resource')
+
+
+class FeatureColumn(object):
+ """Represents a feature column abstraction.
+
+ WARNING: Do not subclass this layer unless you know what you are doing:
+ the API is subject to future changes.
+
+ To distinguish between the concept of a feature family and a specific binary
+ feature within a family, we refer to a feature family like "country" as a
+ feature column. For example, we can have a feature in a `tf.Example` format:
+ {key: "country", value: [ "US" ]}
+ In this example the value of feature is "US" and "country" refers to the
+ column of the feature.
+
+ This class is an abstract class. Users should not create instances of this.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def name(self):
+ """Returns string. Used for naming."""
+ pass
+
+ @abc.abstractmethod
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns intermediate representation (usually a `Tensor`).
+
+ Uses `transformation_cache` to create an intermediate representation
+ (usually a `Tensor`) that other feature columns can use.
+
+ Example usage of `transformation_cache`:
+ Let's say a Feature column depends on raw feature ('raw') and another
+ `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
+ transformation_cache will be used as follows:
+
+ ```python
+ raw_tensor = transformation_cache.get('raw', state_manager)
+ fc_tensor = transformation_cache.get(input_fc, state_manager)
+ ```
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+ """
+ pass
+
+ @abc.abstractproperty
+ def parse_example_spec(self):
+ """Returns a `tf.Example` parsing spec as dict.
+
+ It is used to build the parsing spec for `tf.parse_example`. The returned
+ spec is a dict from keys (`string`) to `VarLenFeature`, `FixedLenFeature`,
+ and other supported objects. Please check documentation of
+ @{tf.parse_example} for all supported spec objects.
+
+ Let's say a Feature column depends on raw feature ('raw') and another
+ `FeatureColumn` (input_fc). One possible implementation of
+ parse_example_spec is as follows:
+
+ ```python
+ spec = {'raw': tf.FixedLenFeature(...)}
+ spec.update(input_fc.parse_example_spec)
+ return spec
+ ```
+ """
+ pass
+
+ def create_state(self, state_manager):
+ """Uses the `state_manager` to create state for the FeatureColumn.
+
+ Args:
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables and variables.
+ """
+ pass
+
+
+class DenseColumn(FeatureColumn):
+ """Represents a column which can be represented as `Tensor`.
+
+ Some examples of this type are: numeric_column, embedding_column,
+ indicator_column.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def variable_shape(self):
+ """`TensorShape` of `get_dense_tensor`, without batch dimension."""
+ pass
+
+ @abc.abstractmethod
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns a `Tensor`.
+
+ The output of this function will be used by model-builder-functions. For
+ example the pseudo code of `input_layer` will be like:
+
+ ```python
+ def input_layer(features, feature_columns, ...):
+ outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
+ return tf.concat(outputs)
+ ```
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ `Tensor` of shape [batch_size] + `variable_shape`.
+ """
+ pass
+
+
+def _create_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ """Creates a weighted sum for a dense/categorical column for linear_model."""
+ if isinstance(column, CategoricalColumn):
+ return _create_categorical_column_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ weight_var=weight_var)
+ else:
+ return _create_dense_column_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ units=units,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ weight_var=weight_var)
+
+
+def _create_dense_column_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ """Create a weighted sum of a dense column for linear_model."""
+ tensor = column.get_dense_tensor(transformation_cache, state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
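+ # `tensor` is now (batch_size, num_elements); the matmul with the
+ # (num_elements, units) weight below yields a (batch_size, units) output.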
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=[num_elements, units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return math_ops.matmul(tensor, weight, name='weighted_sum')
+
+
+class CategoricalColumn(FeatureColumn):
+ """Represents a categorical feature.
+
+ A categorical feature is typically handled with a @{tf.SparseTensor} of IDs.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'IdWeightPair', ('id_tensor', 'weight_tensor'))
+
+ @abc.abstractproperty
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ pass
+
+ @abc.abstractmethod
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Returns an IdWeightPair.
+
+ `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
+ weights.
+
+ `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
+ `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
+ `SparseTensor` of `float` or `None` to indicate all weights should be
+ taken to be 1. If specified, `weight_tensor` must have exactly the same
+ shape and indices as `id_tensor`. The expected `SparseTensor` is the same
+ as the parsing output of a `VarLenFeature`, which is a ragged matrix.
+
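+ For illustration, a ragged batch of two examples with one and two ids
+ respectively may parse to an `id_tensor` with
+ `indices=[[0, 0], [1, 0], [1, 1]]`, three int64 id `values`, and
+ `dense_shape=[2, 2]`, with `weight_tensor=None`.
+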
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ pass
+
+
+def _create_categorical_column_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """Create a weighted sum of a categorical column for linear_model.
+
+ Note to maintainers: as an implementation detail, the weighted sum is
+ implemented via embedding_lookup_sparse for efficiency. Mathematically,
+ the two are the same.
+
+ To be specific, conceptually, categorical column can be treated as multi-hot
+ vector. Say:
+
+ ```python
+ x = [0 0 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `c` in this case, which is the same as `w[2]`.
+
+ Another example is
+
+ ```python
+ x = [0 1 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
+
+ For both cases, we can implement weighted sum via embedding_lookup with
+ sparse_combiner = "sum".
+ """
+
+ sparse_tensors = column.get_sparse_tensors(transformation_cache,
+ state_manager)
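+ # Flatten any non-batch dimensions so the sparse lookup below sees a
+ # [batch_size, ?] SparseTensor of ids.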
+ id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
+ array_ops.shape(sparse_tensors.id_tensor)[0], -1
+ ])
+ weight_tensor = sparse_tensors.weight_tensor
+ if weight_tensor is not None:
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
+
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=(column.num_buckets, units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return _safe_embedding_lookup_sparse(
+ weight,
+ id_tensor,
+ sparse_weights=weight_tensor,
+ combiner=sparse_combiner,
+ name='weighted_sum')
+
+
+class SequenceDenseColumn(FeatureColumn):
+ """Represents dense sequence data."""
+
+ __metaclass__ = abc.ABCMeta
+
+ TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
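+ # E.g. for a batch of two sequences with lengths 3 and 1, `dense_tensor`
+ # is padded to shape [2, 3, ...] and `sequence_length` is [3, 1].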
+
+ @abc.abstractmethod
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """Returns a `TensorSequenceLengthPair`.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ pass
+
+
+class FeatureTransformationCache(object):
+ """Handles caching of transformations while building the model.
+
+ `FeatureColumn` specifies how to digest an input column to the network. Some
+ feature columns require data transformations. This class caches those
+ transformations.
+
+ Some features may be used in more than one place. For example, one can use a
+ bucketized feature by itself and a cross with it. In that case we
+ should create only one bucketization op instead of creating ops for each
+ feature column separately. To handle re-use of transformed columns,
+ `FeatureTransformationCache` caches all previously transformed columns.
+
+ Example:
+ We're trying to use the following `FeatureColumn`s:
+
+ ```python
+ bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
+ keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
+ age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
+ ... = linear_model(features,
+ [bucketized_age, keywords, age_X_keywords])
+ ```
+
+ If we transform each column independently, then we'll get duplication of
+ bucketization (one for cross, one for bucketization itself).
+ The `FeatureTransformationCache` eliminates this duplication.
+ """
+
+ def __init__(self, features):
+ """Creates a `FeatureTransformationCache`.
+
+ Args:
+ features: A mapping from feature column to objects that are `Tensor` or
+ `SparseTensor`, or can be converted to same via
+ `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
+ signifies a base feature (not-transformed). A `FeatureColumn` key
+ means that this `Tensor` is the output of an existing `FeatureColumn`
+ which can be reused.
+ """
+ self._features = features.copy()
+ self._feature_tensors = {}
+
+ def get(self, key, state_manager):
+ """Returns a `Tensor` for the given key.
+
+ A `str` key is used to access a base feature (not-transformed). When a
+ `FeatureColumn` is passed, the transformed feature is returned if it
+ already exists, otherwise the given `FeatureColumn` is asked to provide its
+ transformed output, which is then cached.
+
+ Args:
+ key: a `str` or a `FeatureColumn`.
+ state_manager: A StateManager object that holds the FeatureColumn state.
+
+ Returns:
+ The transformed `Tensor` corresponding to the `key`.
+
+ Raises:
+ ValueError: if key is not found or a transformed `Tensor` cannot be
+ computed.
+ """
+ if key in self._feature_tensors:
+ # FeatureColumn is already transformed or converted.
+ return self._feature_tensors[key]
+
+ if key in self._features:
+ feature_tensor = self._get_raw_feature_as_tensor(key)
+ self._feature_tensors[key] = feature_tensor
+ return feature_tensor
+
+ if isinstance(key, six.string_types):
+ raise ValueError('Feature {} is not in features dictionary.'.format(key))
+
+ if not isinstance(key, FeatureColumn):
+ raise TypeError('"key" must be either a "str" or "FeatureColumn". '
+ 'Provided: {}'.format(key))
+
+ column = key
+ logging.debug('Transforming feature_column %s.', column)
+ transformed = column.transform_feature(self, state_manager)
+ if transformed is None:
+ raise ValueError('Column {} is not supported.'.format(column.name))
+ self._feature_tensors[column] = transformed
+ return transformed
+
+ def _get_raw_feature_as_tensor(self, key):
+ """Gets the raw_feature (keyed by `key`) as `tensor`.
+
+ The raw feature is converted to a (sparse) tensor and its rank may be
+ expanded.
+
+ For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
+ the rank is 1. Dynamic rank is also supported. Rank-0 raw features are not
+ supported and will error out.
+
+ Args:
+ key: A `str` key to access the raw feature.
+
+ Returns:
+ A `Tensor` or `SparseTensor`.
+
+ Raises:
+ ValueError: if the raw feature has rank 0.
+ """
+ raw_feature = self._features[key]
+ feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ raw_feature)
+
+ def expand_dims(input_tensor):
+ # Input_tensor must have rank 1.
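+ # E.g. a rank-1 feature of shape [batch_size] becomes [batch_size, 1] so
+ # downstream columns can assume rank >= 2.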
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ return sparse_ops.sparse_reshape(
+ input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ else:
+ return array_ops.expand_dims(input_tensor, -1)
+
+ rank = feature_tensor.get_shape().ndims
+ if rank is not None:
+ if rank == 0:
+ raise ValueError(
+ 'Feature (key: {}) cannot have rank 0. Given: {}'.format(
+ key, feature_tensor))
+ return feature_tensor if rank != 1 else expand_dims(feature_tensor)
+
+ # Handle dynamic rank.
+ with ops.control_dependencies([
+ check_ops.assert_positive(
+ array_ops.rank(feature_tensor),
+ message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
+ key, feature_tensor))]):
+ return control_flow_ops.cond(
+ math_ops.equal(1, array_ops.rank(feature_tensor)),
+ lambda: expand_dims(feature_tensor),
+ lambda: feature_tensor)
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _shape_offsets(shape):
+ """Returns moving offset for each dimension given shape."""
+ offsets = []
+ for dim in reversed(shape):
+ if offsets:
+ offsets.append(dim * offsets[-1])
+ else:
+ offsets.append(dim)
+ offsets.reverse()
+ return offsets
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
+ """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
+
+ If `input_tensor` is already a `SparseTensor`, just return it.
+
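+ For example (illustrative), an integer `Tensor` `[[1, -1], [-1, 2]]` with
+ the default `ignore_value` of `-1` becomes a `SparseTensor` with
+ `indices=[[0, 0], [1, 1]]`, `values=[1, 2]` and `dense_shape=[2, 2]`.
+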
+ Args:
+ input_tensor: A string or integer `Tensor`.
+ ignore_value: Entries in `dense_tensor` equal to this value will be
+ absent from the resulting `SparseTensor`. If `None`, default value of
+ `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
+
+ Returns:
+ A `SparseTensor` with the same shape as `input_tensor`.
+
+ Raises:
+ ValueError: when `input_tensor`'s rank is `None`.
+ """
+ input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ input_tensor)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ return input_tensor
+ with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
+ if ignore_value is None:
+ if input_tensor.dtype == dtypes.string:
+ # Exception: TF strings are converted to numpy objects by default.
+ ignore_value = ''
+ elif input_tensor.dtype.is_integer:
+ ignore_value = -1 # -1 has a special meaning of missing feature
+ else:
+ # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
+ # constructing a new numpy object of the given type, which yields the
+ # default value for that type.
+ ignore_value = input_tensor.dtype.as_numpy_dtype()
+ ignore_value = math_ops.cast(
+ ignore_value, input_tensor.dtype, name='ignore_value')
+ indices = array_ops.where(
+ math_ops.not_equal(input_tensor, ignore_value), name='indices')
+ return sparse_tensor_lib.SparseTensor(
+ indices=indices,
+ values=array_ops.gather_nd(input_tensor, indices, name='values'),
+ dense_shape=array_ops.shape(
+ input_tensor, out_type=dtypes.int64, name='dense_shape'))
+
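+ # Illustrative example (not part of the original change): with the default
+ # integer ignore_value of -1, a dense Tensor is sparsified as
+ #
+ # dense = [[1, -1], [-1, 3]]
+ # sp = _to_sparse_input_and_drop_ignore_values(dense)
+ # # sp.indices == [[0, 0], [1, 1]], sp.values == [1, 3],
+ # # sp.dense_shape == [2, 2]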
+
+def _normalize_feature_columns(feature_columns):
+ """Normalizes the `feature_columns` input.
+
+ This method converts `feature_columns` to a list as best it can. In
+ addition, it verifies the type and other properties of `feature_columns`
+ required by downstream libraries.
+
+ Args:
+ feature_columns: The raw feature columns, usually passed by users.
+
+ Returns:
+ The normalized feature column list.
+
+ Raises:
+ ValueError: for any invalid inputs, such as empty, duplicated names, etc.
+ """
+ if isinstance(feature_columns, FeatureColumn):
+ feature_columns = [feature_columns]
+
+ if isinstance(feature_columns, collections.Iterator):
+ feature_columns = list(feature_columns)
+
+ if isinstance(feature_columns, dict):
+ raise ValueError('Expected feature_columns to be iterable, found dict.')
+
+ for column in feature_columns:
+ if not isinstance(column, FeatureColumn):
+ raise ValueError('Items of feature_columns must be a FeatureColumn. '
+ 'Given (type {}): {}.'.format(type(column), column))
+ if not feature_columns:
+ raise ValueError('feature_columns must not be empty.')
+ name_to_column = dict()
+ for column in feature_columns:
+ if column.name in name_to_column:
+ raise ValueError('Duplicate feature column name found for columns: {} '
+ 'and {}. This usually means that these columns refer to '
+ 'the same base feature. Either one must be discarded or a '
+ 'duplicated but renamed item must be inserted in '
+ 'features dict.'.format(column,
+ name_to_column[column.name]))
+ name_to_column[column.name] = column
+
+ return feature_columns
+
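+ # Usage sketch (illustrative; `col_a` and `col_b` are hypothetical columns):
+ # a single column is wrapped in a list, an iterator is materialized, and
+ # duplicate names raise, e.g.
+ #
+ # _normalize_feature_columns(numeric_column('price')) # -> [NumericColumn(...)]
+ # _normalize_feature_columns(iter([col_a, col_b])) # -> [col_a, col_b]
+ # _normalize_feature_columns([col_a, col_a]) # -> ValueError (duplicate name)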
+
+class NumericColumn(
+ DenseColumn,
+ collections.namedtuple(
+ 'NumericColumn',
+ ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
+ """see `numeric_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {
+ self.key:
+ parsing_ops.FixedLenFeature(self.shape, self.dtype,
+ self.default_value)
+ }
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class.
+
+ In this case, we apply the `normalizer_fn` to the input tensor.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Normalized input tensor.
+ Raises:
+ ValueError: If a SparseTensor is passed in.
+ """
+ input_tensor = transformation_cache.get(self.key, state_manager)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The corresponding Tensor of numerical column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape(self.shape)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing numeric feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Dense `Tensor` created within `transform_feature`.
+ """
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ return transformation_cache.get(self, state_manager)
+
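+ # Illustrative example (mirrors the unit tests below, not new behavior):
+ # `normalizer_fn` runs inside transform_feature before the cast to float32,
+ # e.g.
+ #
+ # col = numeric_column('price', shape=[2], normalizer_fn=lambda t: t + 2.)
+ # # transforming {'price': [[1., 2.], [5., 6.]]} yields [[3., 4.], [7., 8.]]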
+
+class BucketizedColumn(DenseColumn, CategoricalColumn,
+ collections.namedtuple('BucketizedColumn',
+ ('source_column', 'boundaries'))):
+ """See `bucketized_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_bucketized'.format(self.source_column.name)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.source_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = transformation_cache.get(self.source_column, state_manager)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape(
+ tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return array_ops.one_hot(
+ indices=math_ops.to_int64(input_tensor),
+ depth=len(self.boundaries) + 1,
+ on_value=1.,
+ off_value=0.)
+
+ @property
+ def num_buckets(self):
+ """See `CategoricalColumn` base class."""
+ # By construction, source_column is always one-dimensional.
+ return (len(self.boundaries) + 1) * self.source_column.shape[0]
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ batch_size = array_ops.shape(input_tensor)[0]
+ # By construction, source_column is always one-dimensional.
+ source_dimension = self.source_column.shape[0]
+
+ i1 = array_ops.reshape(
+ array_ops.tile(
+ array_ops.expand_dims(math_ops.range(0, batch_size), 1),
+ [1, source_dimension]),
+ (-1,))
+ i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
+ # Flatten the bucket indices and make them unique across dimensions, e.g.
+ # with k buckets, 2nd-dimension indices will range from k to 2*k-1.
+ bucket_indices = (
+ array_ops.reshape(input_tensor, (-1,)) +
+ (len(self.boundaries) + 1) * i2)
+
+ indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
+ dense_shape = math_ops.to_int64(array_ops.stack(
+ [batch_size, source_dimension]))
+ sparse_tensor = sparse_tensor_lib.SparseTensor(
+ indices=indices,
+ values=bucket_indices,
+ dense_shape=dense_shape)
+ return CategoricalColumn.IdWeightPair(sparse_tensor, None)
+
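+ # Worked example (illustrative): with boundaries [0., 10.] there are three
+ # buckets per source dimension, and get_sparse_tensors offsets each
+ # dimension's bucket ids so they stay unique. For a source column of shape
+ # [2]:
+ #
+ # input row [5., -1.] -> per-dimension buckets [1, 0]
+ # -> flattened ids [1, 0 + 3] == [1, 3]
+ # num_buckets == 3 * 2 == 6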
+
+class EmbeddingColumn(
+ DenseColumn, SequenceDenseColumn,
+ collections.namedtuple(
+ 'EmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
+ """See `embedding_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_embedding'.format(self.categorical_column.name)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Transforms underlying `categorical_column`."""
+ return transformation_cache.get(self.categorical_column, state_manager)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.vector(self.dimension)
+
+ def _get_dense_tensor_internal(self, transformation_cache, state_manager):
+ """Private method that follows the signature of _get_dense_tensor."""
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sparse_ids = sparse_tensors.id_tensor
+ sparse_weights = sparse_tensors.weight_tensor
+
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_weights = state_manager.get_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer)
+
+ if self.ckpt_to_load_from is not None:
+ to_restore = embedding_weights
+ if isinstance(to_restore, variables.PartitionedVariable):
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
+ checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+ self.tensor_name_in_ckpt: to_restore
+ })
+
+ # Return embedding lookup result.
+ return _safe_embedding_lookup_sparse(
+ embedding_weights=embedding_weights,
+ sparse_ids=sparse_ids,
+ sparse_weights=sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns tensor after doing the embedding lookup.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Embedding lookup tensor.
+
+ Raises:
+ ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
+ """
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(transformation_cache, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ transformation_cache, state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _get_graph_for_variable(var):
+ if isinstance(var, variables.PartitionedVariable):
+ return list(var)[0].graph
+ else:
+ return var.graph
+
+
+class SharedEmbeddingColumn(
+ DenseColumn, SequenceDenseColumn,
+ collections.namedtuple(
+ 'SharedEmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'shared_embedding_collection_name', 'ckpt_to_load_from',
+ 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
+ """See `embedding_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_shared_embedding'.format(self.categorical_column.name)
+
+ @property
+ def shared_collection_name(self):
+ """Returns the shared name of this column.
+
+ A group of columns share an embedding. Each of those columns has the same
+ `shared_collection_name`, by which they can be referred to collectively.
+ """
+ return self.shared_embedding_collection_name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class."""
+ return transformation_cache.get(self.categorical_column, state_manager)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.vector(self.dimension)
+
+ def _get_dense_tensor_internal(self, transformation_cache, state_manager):
+ """Private method that follows the signature of _get_dense_tensor."""
+ # This method is called from a variable_scope with name _var_scope_name,
+ # which is shared among all shared embeddings. Open a name_scope here, so
+ # that the ops for different columns have distinct names.
+ with ops.name_scope(None, default_name=self.name):
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sparse_ids = sparse_tensors.id_tensor
+ sparse_weights = sparse_tensors.weight_tensor
+
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_weights = state_manager.get_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer)
+
+ if self.ckpt_to_load_from is not None:
+ to_restore = embedding_weights
+ if isinstance(to_restore, variables.PartitionedVariable):
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
+ checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+ self.tensor_name_in_ckpt: to_restore
+ })
+
+ # Return embedding lookup result.
+ return _safe_embedding_lookup_sparse(
+ embedding_weights=embedding_weights,
+ sparse_ids=sparse_ids,
+ sparse_weights=sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns the embedding lookup result."""
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(transformation_cache, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal(transformation_cache,
+ state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _create_tuple(shape, value):
+ """Returns a tuple with given shape and filled with value."""
+ if shape:
+ return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
+ return value
+
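+ # Worked example (illustrative): the value is broadcast over the shape, e.g.
+ #
+ # _create_tuple([2, 3], 0.) == ((0., 0., 0.), (0., 0., 0.))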
+
+def _as_tuple(value):
+ if not nest.is_sequence(value):
+ return value
+ return tuple([_as_tuple(v) for v in value])
+
+
+def _check_shape(shape, key):
+ """Returns shape if it's valid, raises error otherwise."""
+ assert shape is not None
+ if not nest.is_sequence(shape):
+ shape = [shape]
+ shape = tuple(shape)
+ for dimension in shape:
+ if not isinstance(dimension, int):
+ raise TypeError('shape dimensions must be integer. '
+ 'shape: {}, key: {}'.format(shape, key))
+ if dimension < 1:
+ raise ValueError('shape dimensions must be greater than 0. '
+ 'shape: {}, key: {}'.format(shape, key))
+ return shape
+
+
+def _is_shape_and_default_value_compatible(default_value, shape):
+ """Verifies compatibility of shape and default_value."""
+ # Invalid condition:
+ # * if default_value is not a scalar and shape is empty
+ # * or if default_value is an iterable and shape is not empty
+ if nest.is_sequence(default_value) != bool(shape):
+ return False
+ if not shape:
+ return True
+ if len(default_value) != shape[0]:
+ return False
+ for i in range(shape[0]):
+ if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
+ return False
+ return True
+
+
+def _check_default_value(shape, default_value, dtype, key):
+ """Returns default value as tuple if it's valid, otherwise raises errors.
+
+ This function verifies that `default_value` is compatible with both `shape`
+ and `dtype`. If it is not compatible, it raises an error. If it is compatible,
+ it casts default_value to a tuple and returns it. `key` is used only
+ for error message.
+
+ Args:
+ shape: An iterable of integers specifies the shape of the `Tensor`.
+ default_value: If a single value is provided, the same value will be applied
+ as the default value for every item. If an iterable of values is
+ provided, the shape of the `default_value` should be equal to the given
+ `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ key: Column name, used only for error messages.
+
+ Returns:
+ A tuple which will be used as default value.
+
+ Raises:
+ ValueError: if `default_value` is an iterable but not compatible with
+ `shape`.
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ if default_value is None:
+ return None
+
+ if isinstance(default_value, int):
+ return _create_tuple(shape, default_value)
+
+ if isinstance(default_value, float) and dtype.is_floating:
+ return _create_tuple(shape, default_value)
+
+ if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays
+ default_value = default_value.tolist()
+
+ if nest.is_sequence(default_value):
+ if not _is_shape_and_default_value_compatible(default_value, shape):
+ raise ValueError(
+ 'The shape of default_value must be equal to given shape. '
+ 'default_value: {}, shape: {}, key: {}'.format(
+ default_value, shape, key))
+ # Check if the values in the list are all integers or are convertible to
+ # floats.
+ is_list_all_int = all(
+ isinstance(v, int) for v in nest.flatten(default_value))
+ is_list_has_float = any(
+ isinstance(v, float) for v in nest.flatten(default_value))
+ if is_list_all_int:
+ return _as_tuple(default_value)
+ if is_list_has_float and dtype.is_floating:
+ return _as_tuple(default_value)
+ raise TypeError('default_value must be compatible with dtype. '
+ 'default_value: {}, dtype: {}, key: {}'.format(
+ default_value, dtype, key))
+
+
+class HashedCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('HashedCategoricalColumn',
+ ('key', 'hash_bucket_size', 'dtype'))):
+ """see `categorical_column_with_hash_bucket`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Hashes the values in the feature_column."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError('SparseColumn input must be a SparseTensor.')
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ if self.dtype == dtypes.string:
+ sparse_values = input_tensor.values
+ else:
+ sparse_values = string_ops.as_string(input_tensor.values)
+
+ sparse_id_values = string_ops.string_to_hash_bucket_fast(
+ sparse_values, self.hash_bucket_size, name='lookup')
+ return sparse_tensor_lib.SparseTensor(
+ input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.hash_bucket_size
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
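+ # Usage sketch (illustrative): non-string values are stringified before
+ # hashing, so integer and string inputs share one code path, e.g.
+ #
+ # col = categorical_column_with_hash_bucket('terms', hash_bucket_size=10)
+ # # transform_feature maps each sparse value v to
+ # # string_to_hash_bucket_fast(as_string(v), 10), i.e. an id in [0, 10).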
+
+class VocabularyFileCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('VocabularyFileCategoricalColumn',
+ ('key', 'vocabulary_file', 'vocabulary_size',
+ 'num_oov_buckets', 'dtype', 'default_value'))):
+ """See `categorical_column_with_vocabulary_file`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_file` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ # TODO(rohanj): Use state manager to manage the index table creation.
+ return lookup_ops.index_table_from_file(
+ vocabulary_file=self.vocabulary_file,
+ num_oov_buckets=self.num_oov_buckets,
+ vocab_size=self.vocabulary_size,
+ default_value=self.default_value,
+ key_dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.vocabulary_size + self.num_oov_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class VocabularyListCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple(
+ 'VocabularyListCategoricalColumn',
+ ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
+):
+ """See `categorical_column_with_vocabulary_list`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary list."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_tensor` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ # TODO(rohanj): Use state manager to manage the index table creation.
+ return lookup_ops.index_table_from_tensor(
+ vocabulary_list=tuple(self.vocabulary_list),
+ default_value=self.default_value,
+ num_oov_buckets=self.num_oov_buckets,
+ dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return len(self.vocabulary_list) + self.num_oov_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class IdentityCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('IdentityCategoricalColumn',
+ ('key', 'number_buckets', 'default_value'))):
+ """See `categorical_column_with_identity`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns a SparseTensor with identity values."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if not input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Invalid input, not integer. key: {} dtype: {}'.format(
+ self.key, input_tensor.dtype))
+
+ values = math_ops.to_int64(input_tensor.values, name='values')
+ num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
+ zero = math_ops.to_int64(0, name='zero')
+ if self.default_value is None:
+ # Fail if values are out-of-range.
+ assert_less = check_ops.assert_less(
+ values, num_buckets, data=(values, num_buckets),
+ name='assert_less_than_num_buckets')
+ assert_greater = check_ops.assert_greater_equal(
+ values, zero, data=(values,),
+ name='assert_greater_or_equal_0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ values = array_ops.identity(values)
+ else:
+ # Assign default for out-of-range values.
+ values = array_ops.where(
+ math_ops.logical_or(
+ values < zero, values >= num_buckets, name='out_of_range'),
+ array_ops.fill(
+ dims=array_ops.shape(values),
+ value=math_ops.to_int64(self.default_value),
+ name='default_values'),
+ values)
+
+ return sparse_tensor_lib.SparseTensor(
+ indices=input_tensor.indices,
+ values=values,
+ dense_shape=input_tensor.dense_shape)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.number_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class WeightedCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple(
+ 'WeightedCategoricalColumn',
+ ('categorical_column', 'weight_feature_key', 'dtype'))):
+ """See `weighted_categorical_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_weighted_by_{}'.format(
+ self.categorical_column.name, self.weight_feature_key)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ config = self.categorical_column.parse_example_spec
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
+ def num_buckets(self):
+ """See `DenseColumn` base class."""
+ return self.categorical_column.num_buckets
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ if weight_tensor is None:
+ raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
+ weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ weight_tensor)
+ if self.dtype != weight_tensor.dtype.base_dtype:
+ raise ValueError('Bad dtype, expected {}, but got {}.'.format(
+ self.dtype, weight_tensor.dtype))
+ if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
+ # The weight tensor can be a regular Tensor. In this case, sparsify it.
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
+ if not weight_tensor.dtype.is_floating:
+ weight_tensor = math_ops.to_float(weight_tensor)
+ return (transformation_cache.get(self.categorical_column, state_manager),
+ weight_tensor)
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ tensors = transformation_cache.get(self, state_manager)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
+
+class CrossedColumn(
+ CategoricalColumn,
+ collections.namedtuple('CrossedColumn',
+ ('keys', 'hash_bucket_size', 'hash_key'))):
+ """See `crossed_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ feature_names = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, FeatureColumn):
+ feature_names.append(key.name)
+ else: # key must be a string
+ feature_names.append(key)
+ return '_X_'.join(sorted(feature_names))
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ config = {}
+ for key in self.keys:
+ if isinstance(key, FeatureColumn):
+ config.update(key.parse_example_spec)
+ else: # key must be a string
+ config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
+ return config
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Generates a hashed sparse cross from the input tensors."""
+ feature_tensors = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ feature_tensors.append(transformation_cache.get(key, state_manager))
+ elif isinstance(key, CategoricalColumn):
+ ids_and_weights = key.get_sparse_tensors(transformation_cache,
+ state_manager)
+ if ids_and_weights.weight_tensor is not None:
+ raise ValueError(
+ 'crossed_column does not support weight_tensor, but the given '
+ 'column populates weight_tensor. '
+ 'Given column: {}'.format(key.name))
+ feature_tensors.append(ids_and_weights.id_tensor)
+ else:
+ raise ValueError('Unsupported column type. Given: {}'.format(key))
+ return sparse_ops.sparse_cross_hashed(
+ inputs=feature_tensors,
+ num_buckets=self.hash_bucket_size,
+ hash_key=self.hash_key)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.hash_bucket_size
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+def _collect_leaf_level_keys(cross):
+ """Collects base keys by expanding all nested crosses.
+
+ Args:
+ cross: A `CrossedColumn`.
+
+ Returns:
+ A list of strings or `CategoricalColumn` instances.
+ """
+ leaf_level_keys = []
+ for k in cross.keys:
+ if isinstance(k, CrossedColumn):
+ leaf_level_keys.extend(_collect_leaf_level_keys(k))
+ else:
+ leaf_level_keys.append(k)
+ return leaf_level_keys
+
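+ # Worked example (illustrative): nested crosses are expanded recursively,
+ # e.g.
+ #
+ # inner = crossed_column(['a', 'b'], hash_bucket_size=10)
+ # outer = crossed_column([inner, 'c'], hash_bucket_size=10)
+ # _collect_leaf_level_keys(outer) == ['a', 'b', 'c']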
+
+# TODO(zakaria): Move this to embedding_ops and make it public.
+def _safe_embedding_lookup_sparse(embedding_weights,
+ sparse_ids,
+ sparse_weights=None,
+ combiner='mean',
+ default_id=None,
+ name=None,
+ partition_strategy='div',
+ max_norm=None):
+ """Lookup embedding results, accounting for invalid IDs and empty features.
+
+ The partitioned embedding tensors in `embedding_weights` must all have the
+ same shape except for the first dimension, which is allowed to vary as the
+ vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
+ may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
+ partitioner.
+
+ Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+ with non-positive weight. For an entry with no features, the embedding vector
+ for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+
+ The ids and weights may be multi-dimensional. Embeddings are always aggregated
+ along the last dimension.
+
+ Args:
+ embedding_weights: A list of `P` float `Tensor`s or values representing
+ partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
+ created by partitioning along dimension 0. The total unpartitioned
+ shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
+ vocab size and `e_1, ..., e_m` are the embedding dimensions.
+ sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+ ids. `d_0` is typically batch size.
+ sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+ float weights corresponding to `sparse_ids`, or `None` if all weights
+ are assumed to be 1.0.
+ combiner: A string specifying how to combine embedding results for each
+ entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+ the default.
+ default_id: The id to use for an entry with no features.
+ name: A name for this operation (optional).
+ partition_strategy: A string specifying the partitioning strategy.
+ Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+ max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
+ combining.
+
+ Returns:
+ Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
+
+ Raises:
+ ValueError: if `embedding_weights` is empty.
+ """
+ if embedding_weights is None:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+ if isinstance(embedding_weights, variables.PartitionedVariable):
+ embedding_weights = list(embedding_weights) # get underlying Variables.
+ if not isinstance(embedding_weights, list):
+ embedding_weights = [embedding_weights]
+ if len(embedding_weights) < 1:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+
+ dtype = sparse_weights.dtype if sparse_weights is not None else None
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
+
+ with ops.name_scope(name, 'embedding_lookup',
+ embedding_weights + [sparse_ids,
+ sparse_weights]) as scope:
+ # Reshape higher-rank sparse ids and weights to linear segment ids.
+ original_shape = sparse_ids.dense_shape
+ original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
+ original_rank = (
+ array_ops.size(original_shape)
+ if original_rank_dim.value is None
+ else original_rank_dim.value)
+ sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
+ math_ops.reduce_prod(
+ array_ops.slice(original_shape, [0], [original_rank - 1])),
+ array_ops.gather(original_shape, original_rank - 1)])
+ if sparse_weights is not None:
+ sparse_weights = sparse_tensor_lib.SparseTensor(
+ sparse_ids.indices,
+ sparse_weights.values, sparse_ids.dense_shape)
+
+ # Prune invalid ids and weights.
+ sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+ if combiner != 'sum':
+ sparse_ids, sparse_weights = _prune_invalid_weights(
+ sparse_ids, sparse_weights)
+
+ # Fill in dummy values for empty features, if necessary.
+ sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
+ default_id or
+ 0)
+ if sparse_weights is not None:
+ sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
+
+ result = embedding_ops.embedding_lookup_sparse(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=combiner,
+ partition_strategy=partition_strategy,
+ name=None if default_id is None else scope,
+ max_norm=max_norm)
+
+ if default_id is None:
+ # Broadcast is_row_empty to the same shape as embedding_lookup_result,
+ # for use in Select.
+ is_row_empty = array_ops.tile(
+ array_ops.reshape(is_row_empty, [-1, 1]),
+ array_ops.stack([1, array_ops.shape(result)[1]]))
+
+ result = array_ops.where(is_row_empty,
+ array_ops.zeros_like(result),
+ result,
+ name=scope)
+
+ # Reshape back from linear ids back into higher-dimensional dense result.
+ final_result = array_ops.reshape(
+ result,
+ array_ops.concat([
+ array_ops.slice(
+ math_ops.cast(original_shape, dtypes.int32), [0],
+ [original_rank - 1]),
+ array_ops.slice(array_ops.shape(result), [1], [-1])
+ ], 0))
+ final_result.set_shape(tensor_shape.unknown_shape(
+ (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
+ return final_result
+
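+ # Behavior sketch (illustrative): ids < 0 are pruned, and for combiners other
+ # than 'sum' non-positive weights are pruned as well; rows left empty are
+ # filled with `default_id`, or zeroed in the result when `default_id` is
+ # None. E.g.
+ #
+ # ids row [3, -1] with weights [1., 0.5] -> lookup of id 3 (weight 1.) only
+ # an all-invalid row, default_id=None -> a zero embedding vector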
+
+def _prune_invalid_ids(sparse_ids, sparse_weights):
+ """Prune invalid IDs (< 0) from the input ids and weights."""
+ is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
+ if sparse_weights is not None:
+ is_id_valid = math_ops.logical_and(
+ is_id_valid,
+ array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
+ if sparse_weights is not None:
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
+ return sparse_ids, sparse_weights
+
+
+def _prune_invalid_weights(sparse_ids, sparse_weights):
+ """Prune invalid weights (< 0) from the input ids and weights."""
+ if sparse_weights is not None:
+ is_weights_valid = math_ops.greater(sparse_weights.values, 0)
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
+ return sparse_ids, sparse_weights
+
+
+class IndicatorColumn(DenseColumn, SequenceDenseColumn,
+ collections.namedtuple('IndicatorColumn',
+ ('categorical_column',))):
+ """Represents a one-hot column for use in deep networks.
+
+ Args:
+ categorical_column: A `CategoricalColumn` which is created by
+ `categorical_column_with_*` function.
+ """
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_indicator'.format(self.categorical_column.name)
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+
+ Raises:
+ ValueError: if input rank is not known at graph building time.
+ """
+ id_weight_pair = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ id_tensor = id_weight_pair.id_tensor
+ weight_tensor = id_weight_pair.weight_tensor
+
+ # If the underlying column is weighted, return the input as a dense tensor.
+ if weight_tensor is not None:
+ weighted_column = sparse_ops.sparse_merge(
+ sp_ids=id_tensor,
+ sp_values=weight_tensor,
+ vocab_size=int(self.variable_shape[-1]))
+ # Remove (?, -1) index
+ weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
+ weighted_column.dense_shape)
+ return sparse_ops.sparse_tensor_to_dense(weighted_column)
+
+ dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
+ id_tensor, default_value=-1)
+
+ # One hot must be float for tf.concat reasons since all other inputs to
+ # input_layer are float32.
+ one_hot_id_tensor = array_ops.one_hot(
+ dense_id_tensor,
+ depth=self.variable_shape[-1],
+ on_value=1.0,
+ off_value=0.0)
+
+ # Reduce to get a multi-hot per example.
+ return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ @property
+ def variable_shape(self):
+ """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
+ return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Dense `Tensor` created within `transform_feature`.
+
+ Raises:
+ ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
+ """
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ return transformation_cache.get(self, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ dense_tensor = transformation_cache.get(self, state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _verify_static_batch_size_equality(tensors, columns):
+ # batch_size is a tf.Dimension object.
+ expected_batch_size = None
+ for i in range(0, len(tensors)):
+ if tensors[i].shape[0].value is not None:
+ if expected_batch_size is None:
+ batch_size_column_index = i
+ expected_batch_size = tensors[i].shape[0]
+ elif not expected_batch_size.is_compatible_with(tensors[i].shape[0]):
+ raise ValueError(
+ 'Batch size (first dimension) of each feature must be the same. '
+ 'Batch size of columns ({}, {}): ({}, {})'.format(
+ columns[batch_size_column_index].name, columns[i].name,
+ expected_batch_size, tensors[i].shape[0]))
+
+
+def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+ """Returns a [batch_size] Tensor with per-example sequence length."""
+ with ops.name_scope(None, 'sequence_length') as name_scope:
+ row_ids = sp_tensor.indices[:, 0]
+ column_ids = sp_tensor.indices[:, 1]
+ column_ids += array_ops.ones_like(column_ids)
+ seq_length = math_ops.to_int64(
+ math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # If the last n rows do not have ids, seq_length will have shape
+ # [batch_size - n]. Pad the remaining values with zeros.
+ n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+ padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+ return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
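+ # Worked example (illustrative): the length of each example is the largest
+ # column index seen plus one, and trailing empty rows are padded with zeros.
+ # For a SparseTensor sp with dense_shape [3, 2] and
+ #
+ # sp.indices == [[0, 0], [0, 1], [1, 0]]
+ # _sequence_length_from_sparse_tensor(sp) == [2, 1, 0]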
+
+class SequenceCategoricalColumn(FeatureColumn,
+ collections.namedtuple(
+ 'SequenceCategoricalColumn',
+ ('categorical_column',))):
+ """Represents sequences of categorical data."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.transform_feature(transformation_cache,
+ state_manager)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.categorical_column.num_buckets
+
+ def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
+ """Returns an IdWeightPair.
+
+ `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
+ weights.
+
+ `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
+ `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
+ `SparseTensor` of `float` or `None` to indicate all weights should be
+ taken to be 1. If specified, `weight_tensor` must have exactly the same
+ shape and indices as `sp_ids`. The expected `SparseTensor` is the same as
+ the parsing output of a `VarLenFeature`, which is a ragged matrix.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands final dimension, so that embeddings are not combined during
+ # embedding lookup.
+ check_id_rank = check_ops.assert_equal(
+ array_ops.rank(id_tensor), 2,
+ data=[
+ 'Column {} expected ID tensor of rank 2. '.format(self.name),
+ 'id_tensor shape: ', array_ops.shape(id_tensor)])
+ with ops.control_dependencies([check_id_rank]):
+ id_tensor = sparse_ops.sparse_reshape(
+ id_tensor,
+ shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+ if weight_tensor is not None:
+ check_weight_rank = check_ops.assert_equal(
+ array_ops.rank(weight_tensor), 2,
+ data=[
+ 'Column {} expected weight tensor of rank 2.'.format(self.name),
+ 'weight_tensor shape:', array_ops.shape(weight_tensor)])
+ with ops.control_dependencies([check_weight_rank]):
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor,
+ shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
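+
+ # Illustrative note (added commentary): get_sequence_sparse_tensors expands
+ # the final dimension so that embedding lookups do not combine ids within a
+ # timestep, e.g. an id_tensor of dense_shape [batch, max_len] is reshaped to
+ # dense_shape [batch, max_len, 1] before the IdWeightPair is returned.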
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
new file mode 100644
index 0000000000..80a9d5d40e
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -0,0 +1,6583 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for feature_column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import copy
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as fc_old
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
+from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
+from tensorflow.python.feature_column.feature_column_v2 import InputLayer
+from tensorflow.python.feature_column.feature_column_v2 import StateManager
+from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
+from tensorflow.python.feature_column.feature_column_v2 import _transform_features
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import queue_runner_impl
+
+
+def _initialized_session(config=None):
+ sess = session.Session(config=config)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ return sess
+
+
+class LazyColumnTest(test.TestCase):
+
+ def test_transformations_called_once(self):
+
+ class TransformCounter(FeatureColumn):
+
+ def __init__(self):
+ self.num_transform = 0
+
+ @property
+ def name(self):
+ return 'TransformCounter'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ self.num_transform += 1 # Count transform calls.
+ return transformation_cache.get('a', state_manager)
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ column = TransformCounter()
+ self.assertEqual(0, column.num_transform)
+ transformation_cache.get(column, None)
+ self.assertEqual(1, column.num_transform)
+ transformation_cache.get(column, None)
+ self.assertEqual(1, column.num_transform)
+
+ def test_returns_transform_output(self):
+
+ class Transformer(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ return 'Output'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ column = Transformer()
+ self.assertEqual('Output', transformation_cache.get(column, None))
+ self.assertEqual('Output', transformation_cache.get(column, None))
+
+ def test_does_not_pollute_given_features_dict(self):
+
+ class Transformer(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ return 'Output'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ features = {'a': [[2], [3.]]}
+ transformation_cache = FeatureTransformationCache(features=features)
+ transformation_cache.get(Transformer(), None)
+ self.assertEqual(['a'], list(features.keys()))
+
+ def test_error_if_feature_is_not_found(self):
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(ValueError,
+ 'bbb is not in features dictionary'):
+ transformation_cache.get('bbb', None)
+ with self.assertRaisesRegexp(ValueError,
+ 'bbb is not in features dictionary'):
+ transformation_cache.get(u'bbb', None)
+
+ def test_not_supported_feature_column(self):
+
+ class NotAProperColumn(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotAProperColumn'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ # A proper column must return a transformed tensor; returning None is
+ # invalid.
+ pass
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(ValueError,
+ 'NotAProperColumn is not supported'):
+ transformation_cache.get(NotAProperColumn(), None)
+
+ def test_key_should_be_string_or_feature_column(self):
+
+ class NotAFeatureColumn(object):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(
+ TypeError, '"key" must be either a "str" or "FeatureColumn".'):
+ transformation_cache.get(NotAFeatureColumn(), None)
+
+
+class NumericColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.numeric_column('aaa')
+ self.assertEqual('aaa', a.key)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual((1,), a.shape)
+ self.assertIsNone(a.default_value)
+ self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.numeric_column(key=('aaa',))
+
+ def test_shape_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual((1, 2), a.shape)
+
+ def test_default_value_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', default_value=4.)
+ self.assertEqual((4.,), a.default_value)
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual(((3., 2.),), a.default_value)
+
+ def test_shape_and_default_value_compatibility(self):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.])
+ fc.numeric_column(
+ 'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 3], default_value=[[2, 3], [1, 2], [2, 3.]])
+
+ def test_default_value_type_check(self):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.float32)
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError, 'must be compatible with dtype'):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError,
+ 'default_value must be compatible with dtype'):
+ fc.numeric_column('aaa', default_value=['string'])
+
+ def test_shape_must_be_positive_integer(self):
+ with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 1.0,
+ ])
+
+ with self.assertRaisesRegexp(ValueError,
+ 'shape dimensions must be greater than 0'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 0,
+ ])
+
+ def test_dtype_is_convertible_to_float(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'dtype must be convertible to float'):
+ fc.numeric_column('aaa', dtype=dtypes.string)
+
+ def test_scalar_default_value_fills_the_shape(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], default_value=2.)
+ self.assertEqual(((2., 2., 2.), (2., 2., 2.)), a.default_value)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
+ }, a.parse_example_spec)
+
+ def test_parse_example_no_default_value(self):
+ price = fc.numeric_column('price', shape=[2])
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+ def test_parse_example_with_default_value(self):
+ price = fc.numeric_column('price', shape=[2], default_value=11.)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ no_data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'something_else':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString(),
+ no_data.SerializeToString()],
+ features=fc.make_parse_example_spec([price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
+
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ fc.numeric_column('price', normalizer_fn='NotACallable')
+
+ def test_normalizer_fn_transform_feature(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None)
+ with self.test_session():
+ self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
+
+ def test_get_dense_tensor(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[1., 2.], [5., 6.]]
+ })
+ self.assertEqual(
+ transformation_cache.get(price, None),
+ price.get_dense_tensor(transformation_cache, None))
+
+ def test_sparse_tensor_not_supported(self):
+ price = fc.numeric_column('price')
+ transformation_cache = FeatureTransformationCache({
+ 'price':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+ })
+ with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+ price.transform_feature(transformation_cache, None)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
+ a_copy = copy.deepcopy(a)
+ self.assertEqual(a_copy.name, 'aaa')
+ self.assertEqual(a_copy.shape, (1, 2))
+ self.assertEqual(a_copy.default_value, ((3., 2.),))
+
+ def test_numpy_default_value(self):
+ a = fc.numeric_column(
+ 'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
+ self.assertEqual(a.default_value, ((3., 2.),))
+
+ def test_linear_model(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
+
+class BucketizedColumnTest(test.TestCase):
+
+ def test_invalid_source_column_type(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'source_column must be a column generated with numeric_column'):
+ fc.bucketized_column(a, boundaries=[0, 1])
+
+ def test_invalid_source_column_shape(self):
+ a = fc.numeric_column('aaa', shape=[2, 3])
+ with self.assertRaisesRegexp(
+ ValueError, 'source_column must be one-dimensional column'):
+ fc.bucketized_column(a, boundaries=[0, 1])
+
+ def test_invalid_boundaries(self):
+ a = fc.numeric_column('aaa')
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=None)
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=1.)
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=[1, 0])
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=[1, 1])
+
+ def test_name(self):
+ a = fc.numeric_column('aaa', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertEqual('aaa_bucketized', b.name)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertEqual({
+ 'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
+ }, b.parse_example_spec)
+
+ def test_variable_shape(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ # Column 'aaa' has shape [2] times three buckets -> variable_shape=[2, 3].
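+ # (Boundaries [0, 1] split the real line into three buckets:
+ # (-inf, 0), [0, 1) and [1, inf).)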
+ self.assertAllEqual((2, 3), b.variable_shape)
+
+ def test_num_buckets(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ # Column 'aaa' has shape [2] times three buckets -> num_buckets=6.
+ self.assertEqual(6, b.num_buckets)
+
+ def test_parse_example(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([bucketized_price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+ def test_transform_feature(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformed_tensor = _transform_features({
+ 'price': [[-1., 1.], [5., 6.]]
+ }, [bucketized_price], None)
+ with _initialized_session():
+ self.assertAllEqual([[0, 1], [3, 4]],
+ transformed_tensor[bucketized_price].eval())
+
+ def test_get_dense_tensor_one_input_value(self):
+ """Tests _get_dense_tensor() for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1.], [1.], [5.], [6.]]
+ })
+ with _initialized_session():
+ bucketized_price_tensor = bucketized_price.get_dense_tensor(
+ transformation_cache, None)
+ self.assertAllClose(
+ # One-hot tensor.
+ [[[1., 0., 0., 0., 0.]],
+ [[0., 1., 0., 0., 0.]],
+ [[0., 0., 0., 1., 0.]],
+ [[0., 0., 0., 0., 1.]]],
+ bucketized_price_tensor.eval())
+
+ def test_get_dense_tensor_two_input_values(self):
+ """Tests _get_dense_tensor() for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1., 1.], [5., 6.]]
+ })
+ with _initialized_session():
+ bucketized_price_tensor = bucketized_price.get_dense_tensor(
+ transformation_cache, None)
+ self.assertAllClose(
+ # One-hot tensor.
+ [[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
+ [[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
+ bucketized_price_tensor.eval())
+
+ def test_get_sparse_tensors_one_input_value(self):
+ """Tests _get_sparse_tensors() for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1.], [1.], [5.], [6.]]
+ })
+ with _initialized_session() as sess:
+ id_weight_pair = bucketized_price.get_sparse_tensors(
+ transformation_cache, None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ id_tensor_value = sess.run(id_weight_pair.id_tensor)
+ self.assertAllEqual(
+ [[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices)
+ self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values)
+ self.assertAllEqual([4, 1], id_tensor_value.dense_shape)
+
+ def test_get_sparse_tensors_two_input_values(self):
+ """Tests _get_sparse_tensors() for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1., 1.], [5., 6.]]
+ })
+ with _initialized_session() as sess:
+ id_weight_pair = bucketized_price.get_sparse_tensors(
+ transformation_cache, None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ id_tensor_value = sess.run(id_weight_pair.id_tensor)
+ self.assertAllEqual(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices)
+ # Values 0-4 correspond to the first column of the input price.
+ # Values 5-9 correspond to the second column of the input price.
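+ # For example, price 6. falls into bucket 4 of the second column, so its
+ # id is offset by one full bucket range: 1 * 5 + 4 = 9.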
+ self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values)
+ self.assertAllEqual([2, 2], id_tensor_value.dense_shape)
+
+ def test_sparse_tensor_input_not_supported(self):
+ price = fc.numeric_column('price')
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
+ transformation_cache = FeatureTransformationCache({
+ 'price':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+ })
+ with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+ bucketized_price.transform_feature(transformation_cache, None)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('aaa', shape=[2])
+ a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
+ a_bucketized_copy = copy.deepcopy(a_bucketized)
+ self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized')
+ self.assertAllEqual(a_bucketized_copy.variable_shape, (2, 3))
+ self.assertEqual(a_bucketized_copy.boundaries, (0, 1))
+
+ def test_linear_model_one_input_value(self):
+ """Tests linear_model() for input with shape=[1]."""
+ price = fc_old.numeric_column('price', shape=[1])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = fc.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(bucketized_price_var.assign(
+ [[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
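+ # Each prediction is just the selected bucket weight (bias is still 0.),
+ # e.g. prediction(price=-1.) = 10.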
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_linear_model_two_input_values(self):
+ """Tests linear_model() for input with shape=[2]."""
+ price = fc_old.numeric_column('price', shape=[2])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = fc.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(bucketized_price_var.assign(
+ [[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
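+ # Each prediction sums the two selected bucket weights:
+ # 10. + 70. = 80. and 40. + 100. = 140.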
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
+ def test_keras_linear_model_one_input_value(self):
+ """Tests _LinearModel for input with shape=[1]."""
+ price = fc_old.numeric_column('price', shape=[1])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_keras_linear_model_two_input_values(self):
+ """Tests _LinearModel for input with shape=[2]."""
+ price = fc_old.numeric_column('price', shape=[2])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
+
+class HashedCategoricalColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual('aaa', a.key)
+ self.assertEqual(10, a.hash_bucket_size)
+ self.assertEqual(dtypes.string, a.dtype)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_hash_bucket(('key',), 10)
+
+ def test_bucket_size_should_be_given(self):
+ with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
+ fc.categorical_column_with_hash_bucket('aaa', None)
+
+ def test_bucket_size_should_be_positive(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'hash_bucket_size must be at least 1'):
+ fc.categorical_column_with_hash_bucket('aaa', 0)
+
+ def test_dtype_should_be_string_or_integer(self):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string)
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_hash_bucket('aaa', 10)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(10, column.hash_bucket_size)
+ self.assertEqual(10, column.num_buckets)
+ self.assertEqual(dtypes.string, column.dtype)
+
+ def test_parse_spec_string(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, a.parse_example_spec)
+
+ def test_parse_spec_int(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, a.parse_example_spec)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_strings_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ outputs = _transform_features({'wire': wire_tensor}, [hashed_sparse], None)
+ output = outputs[hashed_sparse]
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [6, 4, 1]
+ with self.test_session():
+ self.assertEqual(dtypes.int64, output.values.dtype)
+ self.assertAllEqual(expected_values, output.values.eval())
+ self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
+ self.assertAllEqual(wire_tensor.dense_shape.eval(),
+ output.dense_shape.eval())
+
+ def test_tensor_dtype_should_be_string_or_integer(self):
+ string_fc = fc.categorical_column_with_hash_bucket(
+ 'a_string', 10, dtype=dtypes.string)
+ int_fc = fc.categorical_column_with_hash_bucket(
+ 'a_int', 10, dtype=dtypes.int32)
+ float_fc = fc.categorical_column_with_hash_bucket(
+ 'a_float', 10, dtype=dtypes.string)
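+ # Note: 'a_float' is declared with a (valid) string dtype; the error
+ # checked below comes from the float-valued input tensor, not from the
+ # column itself.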
+ int_tensor = sparse_tensor.SparseTensor(
+ values=[101],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ string_tensor = sparse_tensor.SparseTensor(
+ values=['101'],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ float_tensor = sparse_tensor.SparseTensor(
+ values=[101.],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ transformation_cache = FeatureTransformationCache({
+ 'a_int': int_tensor,
+ 'a_string': string_tensor,
+ 'a_float': float_tensor
+ })
+ transformation_cache.get(string_fc, None)
+ transformation_cache.get(int_fc, None)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ transformation_cache.get(float_fc, None)
+
+ def test_dtype_should_match_with_tensor(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ transformation_cache.get(hashed_sparse, None)
+
+ def test_ints_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=[101, 201, 301],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ output = transformation_cache.get(hashed_sparse, None)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_int32_64_is_compatible(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant([101, 201, 301], dtype=dtypes.int32),
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ output = transformation_cache.get(hashed_sparse, None)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_get_sparse_tensors(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ transformation_cache = FeatureTransformationCache({
+ 'wire':
+ sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ })
+ id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache,
+ None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(
+ transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_hash_bucket('aaa', 10)
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column._get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ transformation_cache = FeatureTransformationCache({
+ 'wire': (('omar', ''), ('stringer', 'marlo'))
+ })
+ id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache,
+ None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(
+ transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
+
+class CrossedColumnTest(test.TestCase):
+
+ def test_keys_empty(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'keys must be a list with length > 1'):
+ fc.crossed_column([], 10)
+
+ def test_keys_length_one(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'keys must be a list with length > 1'):
+ fc.crossed_column(['a'], 10)
+
+ def test_key_type_unsupported(self):
+ with self.assertRaisesRegexp(ValueError, 'Unsupported key type'):
+ fc.crossed_column(['a', fc.numeric_column('c')], 10)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'categorical_column_with_hash_bucket is not supported'):
+ fc.crossed_column(
+ ['a', fc.categorical_column_with_hash_bucket('c', 10)], 10)
+
+ def test_hash_bucket_size_negative(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], -1)
+
+ def test_hash_bucket_size_zero(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], 0)
+
+ def test_hash_bucket_size_none(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], None)
+
+ def test_name(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_name_ordered_alphabetically(self):
+ """Tests that the name does not depend on the order of given columns."""
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+
+ crossed2 = fc.crossed_column([crossed1, 'c', b], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_name_leaf_keys_ordered_alphabetically(self):
+ """Tests that the name does not depend on the order of given columns."""
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d2', 'c'], 10)
+
+ crossed2 = fc.crossed_column([crossed1, 'd1', b], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed = fc.crossed_column([b, 'c'], 10)
+ self.assertEqual({
+ 'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
+ 'c': parsing_ops.VarLenFeature(dtypes.string),
+ }, crossed.parse_example_spec)
+
+ def test_num_buckets(self):
+ a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed = fc.crossed_column([b, 'c'], 15)
+ self.assertEqual(15, crossed.num_buckets)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
+ crossed2_copy = copy.deepcopy(crossed2)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2_copy.name)
+ self.assertEqual(15, crossed2_copy.hash_bucket_size)
+ self.assertEqual(5, crossed2_copy.hash_key)
+
+ def test_parse_example(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.])),
+ 'wire':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([price_cross_wire]))
+ self.assertIn('price', features)
+ self.assertIn('wire', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+ wire_sparse = features['wire']
+ self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
+ # Use byte constants to pass the open-source test.
+ self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
+ self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
+
+ def test_transform_feature(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ hash_bucket_size = 10
+ price_cross_wire = fc.crossed_column(
+ [bucketized_price, 'wire'], hash_bucket_size)
+ features = {
+ 'price': constant_op.constant([[1., 2.], [5., 6.]]),
+ 'wire': sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }
+ outputs = _transform_features(features, [price_cross_wire], None)
+ output = outputs[price_cross_wire]
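+ # Crossing the 2-element bucketized price with 'wire' yields 2 * 1 = 2
+ # cross terms for the first row and 2 * 2 = 4 for the second, hence
+ # dense_shape [2, 4].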
+ with self.test_session() as sess:
+ output_val = sess.run(output)
+ self.assertAllEqual(
+ [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
+ for val in output_val.values:
+ self.assertIn(val, list(range(hash_bucket_size)))
+ self.assertAllEqual([2, 4], output_val.dense_shape)
+
+ def test_get_sparse_tensors(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ 'd1':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['d1A', 'd1B', 'd1C'],
+ dense_shape=(2, 2)),
+ 'd2':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['d2A', 'd2B', 'd2C'],
+ dense_shape=(2, 2)),
+ })
+ id_weight_pair = crossed2.get_sparse_tensors(transformation_cache, None)
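+ # Row 1 crosses 2 bucket ids x 2 'c' values x 2 'd1' values x 2 'd2'
+ # values = 16 terms, hence dense_shape (2, 16) below.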
+ with _initialized_session():
+ id_tensor_eval = id_weight_pair.id_tensor.eval()
+ self.assertAllEqual(
+ ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13),
+ (1, 14), (1, 15)),
+ id_tensor_eval.indices)
+ # Check exact hashed output. If hashing changes this test will break.
+ # All values are within [0, hash_bucket_size).
+ expected_values = (
+ 6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11)
+ self.assertAllEqual(expected_values, id_tensor_eval.values)
+ self.assertAllEqual((2, 16), id_tensor_eval.dense_shape)
+
+ def test_get_sparse_tensors_simple(self):
+ """Same as test_get_sparse_tensors, but with simpler values."""
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ })
+ id_weight_pair = crossed.get_sparse_tensors(transformation_cache, None)
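+ # Row 1 crosses 2 bucket ids x 2 'c' values = 4 terms, hence
+ # dense_shape (2, 4) below.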
+ with _initialized_session():
+ id_tensor_eval = id_weight_pair.id_tensor.eval()
+ self.assertAllEqual(
+ ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)),
+ id_tensor_eval.indices)
+ # Check exact hashed output. If hashing changes this test will break.
+ # All values are within [0, hash_bucket_size).
+ expected_values = (1, 0, 1, 3, 4, 2)
+ self.assertAllEqual(expected_values, id_tensor_eval.values)
+ self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
+
+ def test_linear_model(self):
+ """Tests linear_model.
+
+ Uses data from test_get_sparse_tensors_simple.
+ """
+ a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc_old.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'a': constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
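+ # Row 0 sums the weights at ids (1, 0): 2. + 1. = 3.
+ # Row 1 sums the weights at ids (1, 3, 4, 2): 2. + 4. + 5. + 3. = 14.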
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return fc_old._CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ fc.linear_model({
+ t.name: sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
+ def test_keras_linear_model(self):
+ """Tests _LinearModel.
+
+ Uses data from test_get_sparse_tensors_simple.
+ """
+ a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc_old.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
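+ # As above, row sums are 2. + 1. = 3. and 2. + 4. + 5. + 3. = 14.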
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_keras_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name:
+ parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name):
+ parsing_ops.VarLenFeature(dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return fc_old._CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ get_keras_linear_model_predictions({
+ t.name:
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name):
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
+
+def get_linear_model_bias(name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+
+def get_linear_model_column_var(column, name='linear_model'):
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ name + '/' + column.name)[0]
+
+
+def get_keras_linear_model_predictions(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
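+ """Test helper: builds a _LinearModel and returns its predictions."""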
+ keras_linear_model = _LinearModel(
+ feature_columns,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ name='linear_model')
+ retval = keras_linear_model(features) # pylint: disable=not-callable
+ if cols_to_vars is not None:
+ cols_to_vars.update(keras_linear_model.cols_to_vars())
+ return retval
+
+
+class LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc.linear_model(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc_old._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ fc.linear_model(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_dense_bias(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
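+ # prediction = price * weight + bias: 1. * 10. + 5. = 15. and
+ # 5. * 10. + 5. = 55.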
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
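+ # Row 0 (id 2): 1000. + 5. = 1005.
+ # Row 1 (ids 0 and 3): 10. + 10000. + 5. = 10015.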
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
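+ # Row 0: wire id 2 (1000.) + price 1. * 10. + bias 5. = 1015.
+ # Row 1: wire ids 0 and 3 (10. + 10000.) + price 5. * 10. + bias 5.
+ # = 10065.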
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self, inputs, weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = fc.linear_model(features, [dense_and_sparse_column])
+ bias = get_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(dense_and_sparse_column_var.assign(
+ [[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
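+ # Each output unit has its own weight and bias, e.g. unit 2 for
+ # price 5.: 5. * 1000. + 7. = 5007.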
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
+ 1000., 1100., 1200.
+ ], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
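+ # Row 1 sums weight rows 0 and 3 plus the bias, e.g. unit 0:
+ # 10. + 10000. + 5. = 10015.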
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc.linear_model(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
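+ # Dot product per example: 1. * 10. + 2. * 100. = 210. and
+ # 5. * 10. + 6. * 100. = 650.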
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
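+ # With 'mean', row 1 averages ids 0 and 3: (10. + 10000.) / 2 + 5. = 5010.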
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_sparse_combiner_with_negative_weights(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {
+ 'wire_cast': wire_tensor,
+ 'weights': constant_op.constant([[1., 1., -1.0]])
+ }
+ predictions = fc.linear_model(
+ features, [wire_cast_weights], sparse_combiner='sum')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
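+ # Row 1: 1. * 10. + (-1.) * 10000. + 5. = -9985.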
+ self.assertAllClose([[1005.], [-9985.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
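+ # E.g. first example, unit 0: 1. * 1. + 2. * 10. + 2. = 23.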
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc.linear_model(features, [price])
+
+ def test_dense_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
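+ # After reshaping, this is the same dot product as in
+ # test_dense_multi_dimension: 210. and 650.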
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3.], [4.]]
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
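+ # Row 0: 1. * 10. + 2. * 100. + 3. * 1000. + 7. = 3217.
+ # Row 1: 5. * 10. + 6. * 100. + 4. * 1000. + 7. = 4657.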
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_fills_cols_to_vars(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ cols_to_vars = {}
+ fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ self.assertAllEqual(cols_to_vars['bias'], [bias])
+ self.assertAllEqual(cols_to_vars[price1], [price1_var])
+ self.assertAllEqual(cols_to_vars[price2], [price2_var])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2', shape=3)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [6., 7.]],
+ 'price2': [[3., 4., 5.], [8., 9., 10.]]
+ }
+ cols_to_vars = {}
+ with variable_scope.variable_scope(
+ 'linear',
+ partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
+ fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
+ with _initialized_session():
+ self.assertEqual([0.], cols_to_vars['bias'][0].eval())
+ # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
+ self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
+ # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
+ # a [1, 1] Variable.
+ self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+
+ def test_dense_collection(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.linear_model(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.linear_model(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ fc.linear_model(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]],  # batch size = 2
+          'price3': [[3.], [4.], [5.]]  # batch size = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ fc.linear_model(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]]  # batch size = 2
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.linear_model(features, [price_buckets, body_style])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
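+      # Row 0: price -1. falls in bucket 0 (weight 10.) and 'sedan' is index 2
+      # (weight -1000.). Row 1: price 2. falls in bucket 1 (weight 100.) and
+      # 'hardtop' is index 0 (weight -10.). Bias is 5.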
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+    # Provides 1-dim dense and sparse tensors.
+ features = {
+ 'price': constant_op.constant([-1., 12.,]),
+ 'body-style': sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = fc.linear_model(features, [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
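+      # Row 0: price -1. -> bucket 0 (10.), 'sedan' -> index 2 (-1000.).
+      # Row 1: price 12. -> bucket 2 (1000.), 'hardtop' -> index 0 (-10.).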
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+    # Provides unknown-shape placeholders fed with 1-dim dense and sparse data.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = fc.linear_model(features, [price_buckets, body_style, country])
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc.linear_model(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc.linear_model(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+ def test_multiple_linear_models(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features1 = {'price': [[1.], [5.]]}
+ features2 = {'price': [[2.], [10.]]}
+ predictions1 = fc.linear_model(features1, [price])
+ predictions2 = fc.linear_model(features2, [price])
+ bias1 = get_linear_model_bias(name='linear_model')
+ bias2 = get_linear_model_bias(name='linear_model_1')
+ price_var1 = get_linear_model_column_var(price, name='linear_model')
+ price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias1.eval())
+ sess.run(price_var1.assign([[10.]]))
+ sess.run(bias1.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions1.eval())
+ self.assertAllClose([0.], bias2.eval())
+ sess.run(price_var2.assign([[10.]]))
+ sess.run(bias2.assign([5.]))
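+        # The second model owns separate variables: 2.*10. + 5. = 25. and
+        # 10.*10. + 5. = 105.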
+ self.assertAllClose([[25.], [105.]], predictions2.eval())
+
+
+class _LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ get_keras_linear_model_predictions(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc_old._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_dense_bias(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
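+        # Row 0: id 2 -> 1000. + bias 5. = 1005.
+        # Row 1: ids 0 and 3 -> 10. + 10000. + 5. = 10015.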
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [dense_and_sparse_column])
+ bias = get_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(
+ dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
+ [10000.]]))
+ sess.run(bias.assign([5.]))
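+        # _get_sparse_tensors returns ids [2, 0, 3], matching the layout in
+        # test_sparse_bias, so the expected predictions are the same.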
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
+ [1000., 1100.,
+ 1200.], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
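+        # Row 0: id 2 -> [1000., 1100., 1200.] + bias [5., 6., 7.].
+        # Row 1: ids 0 and 3 -> [10., 11., 12.] + [10000., 11000., 12000.]
+        # + bias.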
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
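+        # [1., 2.] dot [10., 100.] = 210.; [5., 6.] dot [10., 100.] = 650.
+        # The bias stays at its zero initialization.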
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
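+        # Batch 0 has ids 2 and 0 -> 1000. + 10. = 1010.; batch 1 has ids 3
+        # and 2 -> 10000. + 1000. = 11000. The bias stays 0.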
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
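+        # With sparse_combiner='mean', row 1 averages its two ids:
+        # (10. + 10000.) / 2 + 5. = 5010. Row 0 has one id: 1000. + 5. = 1005.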
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
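+        # Row 0: 1.*[1., 2., 3.] + 2.*[10., 100., 1000.] + [2., 3., 4.].
+        # Row 1: 5.*[1., 2., 3.] + 6.*[10., 100., 1000.] + [2., 3., 4.].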
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ get_keras_linear_model_predictions(features, [price])
+
+ def test_dense_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
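+        # Row 0: 1.*10. + 2.*100. + 3.*1000. + 7. = 3217.
+        # Row 1: 5.*10. + 6.*100. + 4.*1000. + 7. = 4657.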
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_fills_cols_to_vars(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ cols_to_vars = {}
+ get_keras_linear_model_predictions(
+ features, [price1, price2], cols_to_vars=cols_to_vars)
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ self.assertAllEqual(cols_to_vars['bias'], [bias])
+ self.assertAllEqual(cols_to_vars[price1], [price1_var])
+ self.assertAllEqual(cols_to_vars[price2], [price2_var])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2', shape=3)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [6., 7.]],
+ 'price2': [[3., 4., 5.], [8., 9., 10.]]
+ }
+ cols_to_vars = {}
+ with variable_scope.variable_scope(
+ 'linear',
+ partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
+ get_keras_linear_model_predictions(
+ features, [price1, price2], cols_to_vars=cols_to_vars)
+ with _initialized_session():
+ self.assertEqual([0.], cols_to_vars['bias'][0].eval())
+ # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
+ self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
+ # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
+ # a [1, 1] Variable.
+ self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+
+ def test_dense_collection(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(
+ features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': [[1.], [5.], [7.]],  # batch size = 3
+          'price2': [[3.], [4.]]  # batch size = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ get_keras_linear_model_predictions(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]],  # batch size = 2
+          'price3': [[3.], [4.], [5.]]  # batch size = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ get_keras_linear_model_predictions(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]]  # batch size = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+    # Provides 1-dim dense and sparse tensors.
+ features = {
+ 'price':
+ constant_op.constant([
+ -1.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+    # Provides unknown-shape placeholders fed with 1-dim dense and sparse data.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = get_keras_linear_model_predictions(
+ features, [price_buckets, body_style, country])
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ get_keras_linear_model_predictions(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = get_keras_linear_model_predictions(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+class InputLayerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_retrieving_input(self):
+ features = {'a': [0.]}
+ input_layer = InputLayer(fc_old.numeric_column('a'))
+ inputs = self.evaluate(input_layer(features))
+ self.assertAllClose([[0.]], inputs)
+
+ def test_reuses_variables(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ inputs = input_layer(features)
+ variables = input_layer.variables
+
+ # Sanity check: test that the inputs are correct.
+ self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
+
+ # Check that only one variable was created.
+ self.assertEqual(1, len(variables))
+
+ # Check that invoking input_layer on the same features does not create
+ # additional variables
+ _ = input_layer(features)
+ self.assertEqual(1, len(variables))
+ self.assertEqual(variables[0], input_layer.variables[0])
+
+ def test_feature_column_input_layer_gradient(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ def scale_matrix():
+ matrix = input_layer(features)
+ return 2 * matrix
+
+ # Sanity check: Verify that scale_matrix returns the correct output.
+ self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
+
+ # Check that the returned gradient is correct.
+ grad_function = backprop.implicit_grad(scale_matrix)
+ grads_and_vars = grad_function()
+      indexed_slice = grads_and_vars[0][0]
+      gradient = indexed_slice.values
+
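+      # scale_matrix doubles each embedding lookup, so the gradient is 2 for
+      # every coordinate of the three looked-up rows.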
+ self.assertAllEqual([0, 1, 2], indexed_slice.indices)
+ self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+
+
+class FunctionalInputLayerTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc.input_layer(features={}, feature_columns=[])
+
+ def test_should_be_dense_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_bare_column(self):
+ with ops.Graph().as_default():
+      features = {'a': [0.]}
+ net = fc.input_layer(features, fc_old.numeric_column('a'))
+ with _initialized_session():
+ self.assertAllClose([[0.]], net.eval())
+
+ def test_column_generator(self):
+ with ops.Graph().as_default():
+      features = {'a': [0.], 'b': [1.]}
+ columns = (fc_old.numeric_column(key) for key in features)
+ net = fc.input_layer(features, columns)
+ with _initialized_session():
+ self.assertAllClose([[0., 1.]], net.eval())
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_one_column(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1.], [5.]], net.eval())
+
+ def test_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc.input_layer(features, [price])
+
+ def test_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3.], [4.]]
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session():
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
+
+ def test_fills_cols_to_vars(self):
+ # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ price1 = fc_old.numeric_column('price1')
+ dense_feature = fc_old.numeric_column('dense_feature')
+ dense_feature_bucketized = fc_old.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1')
+ dense_feature = fc_old.numeric_column('dense_feature')
+ dense_feature_bucketized = fc_old.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
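+      # The [5, 10] embedding table is sharded into [2, 10], [2, 10], [1, 10].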
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ with ops.Graph().as_default():
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ }
+ net1 = fc.input_layer(features, [price_a, price_b])
+ net2 = fc.input_layer(features, [price_b, price_a])
+ with _initialized_session():
+ self.assertAllClose([[1., 3.]], net1.eval())
+ self.assertAllClose([[1., 3.]], net2.eval())
+
+ def test_fails_for_categorical_column(self):
+ animal = fc_old.categorical_column_with_identity('animal', num_buckets=4)
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
+ fc.input_layer(features, [animal])
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': [[1.], [5.], [7.]],  # batch size = 3
+          'price2': [[3.], [4.]]  # batch size = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ fc.input_layer(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]],  # batch size = 2
+          'price3': [[3.], [4.], [5.]]  # batch size = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+          r'Batch size \(first dimension\) of each feature must be same.'):
+ fc.input_layer(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 3
+          'price2': [[3.], [4.]]  # batch size = 2
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'Dimensions of inputs should match'):
+ sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batch size = 2
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ net,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_multiple_layers_with_same_embedding_column(self):
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+
+ with ops.Graph().as_default():
+ features = {
+ 'sparse_feature': [['a'], ['x']],
+ }
+ all_cols = [some_embedding_column]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that 2 variables get created in this case.
+ self.assertEqual(2, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ expected_var_names = [
+ 'input_layer/sparse_feature_embedding/embedding_weights:0',
+ 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ ]
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column(self):
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ all_cols = [embedding_column_a, embedding_column_b]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+ all_cols = [embedding_column_a, embedding_column_b]
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+
+ with ops.Graph().as_default():
+ features1 = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+
+ fc.input_layer(features1, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_with_numpy_input_fn(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ # one_hot_body_style has 3 dims in input_layer.
+ one_hot_body_style = fc_old.indicator_column(body_style)
+ # embedded_body_style has 5 dims in input_layer.
+ embedded_body_style = fc_old.embedding_column(
+ body_style, dimension=5, initializer=_initializer)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([11., 12., 13., 14.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_body_style])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[11., 12., 13., 14., 15., 0., 0., 1., 11.],
+ [1., 2., 3., 4., 5., 1., 0., 0., 12]],
+ sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc_old.indicator_column(body_style)
+
+    # embedded_country has 5 dims in input_layer.
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc_old.embedding_column(
+ country, dimension=5, initializer=_initializer)
+
+    # Provides 1-dim dense and sparse tensors.
+ features = {
+ 'price': constant_op.constant([11., 12.,]),
+ 'body-style': sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+        # This is a dense tensor for the categorical column.
+ 'country': constant_op.constant(['CA', 'US']),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+ self.assertEqual(1, features['country'].shape.ndims)
+
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `one_hot_body_style`,
+      # `embedded_country`, and `price` in order.
+ self.assertAllEqual(
+ [[0., 0., 1., 11., 12., 13., 14., 15., 11.],
+ [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
+ sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ embedding_values = (
+ (1., 2.), # id 0
+ (6., 7.), # id 1
+ (11., 12.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc_old.indicator_column(body_style)
+
+    # embedded_country has 2 dims in input_layer.
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc_old.embedding_column(
+ country, dimension=2, initializer=_initializer)
+
+    # Provides unknown-shape dense and sparse placeholders.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+        # This is a dense tensor for the categorical column.
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+ self.assertIsNone(features['country'].shape.ndims)
+
+ price_data = np.array([11., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,))
+ country_data = np.array([['US'], ['CA']])
+
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 2, net.shape[1])
+ with _initialized_session() as sess:
+
+      # Each row is formed by concatenating `one_hot_body_style`,
+      # `embedded_country`, and `price` in order.
+ self.assertAllEqual(
+ [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc.input_layer(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc.input_layer(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+class MakeParseExampleSpecTest(test.TestCase):
+
+ class _TestFeatureColumn(FeatureColumn,
+ collections.namedtuple('_TestFeatureColumn',
+                                                 ('parse_spec',))):
+
+ @property
+ def name(self):
+      return '_TestFeatureColumn'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ pass
+
+ @property
+ def parse_example_spec(self):
+ return self.parse_spec
+
+ def test_no_feature_columns(self):
+ actual = fc.make_parse_example_spec([])
+ self.assertDictEqual({}, actual)
+
+ def test_invalid_type(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'All feature_columns must be FeatureColumn instances.*invalid_column'):
+ fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}), 'invalid_column'))
+
+ def test_one_feature_column(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),))
+ self.assertDictEqual({key1: parse_spec1}, actual)
+
+ def test_two_feature_columns(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ key2 = 'key2'
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key2: parse_spec2})))
+ self.assertDictEqual({key1: parse_spec1, key2: parse_spec2}, actual)
+
+ def test_equal_keys_different_parse_spec(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'feature_columns contain different parse_spec for key key1'):
+ fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key1: parse_spec2})))
+
+ def test_equal_keys_equal_parse_spec(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key1: parse_spec1})))
+ self.assertDictEqual({key1: parse_spec1}, actual)
+
+ def test_multiple_features_dict(self):
+    """parse_spec for one column is a dict with length > 1."""
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ key2 = 'key2'
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ key3 = 'key3'
+ parse_spec3 = parsing_ops.VarLenFeature(dtype=dtypes.int32)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key2: parse_spec2, key3: parse_spec3})))
+ self.assertDictEqual(
+ {key1: parse_spec1, key2: parse_spec2, key3: parse_spec3}, actual)
+
+
+def _assert_sparse_tensor_value(test_case, expected, actual):
+ test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+ test_case.assertAllEqual(expected.indices, actual.indices)
+
+ test_case.assertEqual(
+ np.array(expected.values).dtype, np.array(actual.values).dtype)
+ test_case.assertAllEqual(expected.values, actual.values)
+
+ test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+ test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+
+class VocabularyFileCategoricalColumnTest(test.TestCase):
+
+ def setUp(self):
+ super(VocabularyFileCategoricalColumnTest, self).setUp()
+
+ # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
+ self._warriors_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/warriors_vocabulary.txt')
+ self._warriors_vocabulary_size = 5
+
+ # Contains strings, character names from 'The Wire': omar, stringer, marlo
+ self._wire_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/wire_vocabulary.txt')
+ self._wire_vocabulary_size = 3
+
+ def test_defaults(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_file(
+ key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ self.assertEqual(7, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(7, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_vocabulary_file_none(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=None, vocabulary_size=3)
+
+ def test_vocabulary_file_empty_string(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='', vocabulary_size=3)
+
+ def test_invalid_vocabulary_file(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_vocabulary_size(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=-1)
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=0)
+
+ def test_too_large_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size + 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_num_oov_buckets(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ num_oov_buckets=-1)
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ dtype=dtypes.float64)
+
+ def test_invalid_buckets_and_default_value(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'both num_oov_buckets and default_value'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100,
+ default_value=2)
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ dtype=dtypes.string)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
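+ # Expected ids from the wire vocabulary: 'marlo' -> 2, 'skywalker' (not in vocabulary) -> -1, 'omar' -> 0.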
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_none_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
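+ # vocabulary_size is omitted, so it is inferred from the file; lookups match the explicit-size case above.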
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
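+ # The empty string in the dense input is treated as a missing value when the input is converted to sparse, so only three ids appear below.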
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+ values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+ dense_shape=(2, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
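+ # In-vocabulary ids are unchanged ('marlo' -> 2, 'omar' -> 0); OOV values hash deterministically into the 100 extra buckets [3, 103), here 33 and 62.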
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 33, 0, 62), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_small_vocabulary_size(self):
+ # 'marlo' is the last entry in our vocabulary file, so by setting
+ # `vocabulary_size` to 1 less than the number of entries in the file, we
+ # take 'marlo' out of the vocabulary.
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size - 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((-1, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
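+ # The warriors vocabulary is (30, 35, 11, 23, 22), so 11 -> 2, 100 (OOV) -> -1, 30 -> 0, 22 -> 4.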
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ default_value=default_value)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
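+ # In-vocabulary ids map as above (11 -> 2, 30 -> 0, 22 -> 4); the OOV value 100 hashes into one of the 100 extra buckets, deterministically 60 here.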
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 60, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
+class VocabularyListCategoricalColumnTest(test.TestCase):
+
+ def test_defaults_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_list(
+ key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
+ def test_defaults_int(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36))
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
+ default_value=-99)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.float32)
+
+ def test_invalid_mapping_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12., 24., 36.))
+
+ def test_mismatched_int_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.int32)
+
+ def test_mismatched_string_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string)
+
+ def test_none_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=None)
+
+ def test_empty_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=tuple([]))
+
+ def test_duplicate_mapping(self):
+ with self.assertRaisesRegexp(ValueError, 'Duplicate keys'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 12))
+
+ def test_invalid_num_oov_buckets(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36),
+ num_oov_buckets=-1)
+
+ def test_invalid_buckets_and_default_value(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'both num_oov_buckets and default_value'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=(12, 24, 36),
+ num_oov_buckets=100,
+ default_value=2)
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=(12, 24, 36))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example_string(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_parse_example_int(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(11, 21, 31))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=[11, 21],
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
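+ # Expected ids: 'marlo' -> 2, 'skywalker' (not in the list) -> -1, 'omar' -> 0.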
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+ values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+ dense_shape=(2, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 33, 0, 62), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((11, 100, 30, 22), dtype=np.int32),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32,
+ default_value=default_value)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa':
+ np.array(
+ ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32)
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 60, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
+class IdentityCategoricalColumnTest(test.TestCase):
+
+ def test_constructor(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_invalid_num_buckets_zero(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=0)
+
+ def test_invalid_num_buckets_negative(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
+
+ def test_invalid_default_value_too_small(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=-1)
+
+ def test_invalid_default_value_too_big(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=3)
+
+ def test_invalid_input_dtype(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=30)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([11, 21], dtype=np.int64),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': ((0, -1), (1, 0))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
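+ # For integer inputs, -1 marks a missing value in the dense-to-sparse conversion and is dropped from the sparse ids.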
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_inputs_too_small(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_greater_or_equal_0'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_inputs_too_big(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 99, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_less_than_num_buckets'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_default_value(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 99),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
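+ # The out-of-range inputs -1 and 99 both map to default_value 3; the in-range 1 passes through.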
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int32)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=np.array((2, 2), dtype=np.int64)),
+ id_weight_pair.id_tensor.eval(feed_dict={
+ input_indices: ((0, 0), (1, 0), (1, 1)),
+ input_values: (1, -1, 99),
+ input_shape: (2, 2),
+ }))
+
+ def test_linear_model(self):
+ column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+
+class TransformFeaturesTest(test.TestCase):
+
+ # All transform tests are distributed across the individual column tests.
+ # Here we only test the multi-column case and naming.
+ def test_transform_multi_column(self):
+ bucketized_price = fc.bucketized_column(
+ fc.numeric_column('price'), boundaries=[0, 2, 4, 6])
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[-1.], [5.]],
+ 'wire':
+ sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ }
+ transformed = _transform_features(features,
+ [bucketized_price, hashed_sparse], None)
+ with _initialized_session():
+ self.assertIn(bucketized_price.name, transformed[bucketized_price].name)
+ self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval())
+ self.assertIn(hashed_sparse.name, transformed[hashed_sparse].name)
+ self.assertAllEqual([6, 4, 1], transformed[hashed_sparse].values.eval())
+
+ def test_column_order(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _LoggerColumn(FeatureColumn):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+ def transform_feature(self, transformation_cache, state_manager):
+ self.call_order = call_logger['count']
+ call_logger['count'] += 1
+ return 'Anything'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ with ops.Graph().as_default():
+ column1 = _LoggerColumn('1')
+ column2 = _LoggerColumn('2')
+ call_logger = {'count': 0}
+ _transform_features({}, [column1, column2], None)
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+ call_logger = {'count': 0}
+ _transform_features({}, [column2, column1], None)
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+
+class IndicatorColumnTest(test.TestCase):
+
+ def test_indicator_column(self):
+ a = fc.categorical_column_with_hash_bucket('a', 4)
+ indicator_a = fc.indicator_column(a)
+ self.assertEqual(indicator_a.categorical_column.name, 'a')
+ self.assertEqual(indicator_a.name, 'a_indicator')
+ self.assertEqual(indicator_a.variable_shape, [1, 4])
+
+ b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100)
+ indicator_b = fc.indicator_column(b)
+ self.assertEqual(indicator_b.categorical_column.name, 'b')
+ self.assertEqual(indicator_b.name, 'b_indicator')
+ self.assertEqual(indicator_b.variable_shape, [1, 100])
+
+ def test_1D_shape_succeeds(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_hash_bucket('animal', 4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal': ['fox', 'fox']
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
+
+ def test_2D_shape_succeeds(self):
+ # TODO(ispir/cassandrax): Switch to categorical_column_with_keys when ready.
+ animal = fc.indicator_column(
+ fc.categorical_column_with_hash_bucket('animal', 4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0]],
+ values=['fox', 'fox'],
+ dense_shape=[2, 1])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
+
+ def test_multi_hot(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+
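+ # Repeated ids within one example are summed in the indicator output, so two occurrences of id 1 produce a 2. in that slot.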
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
+
+ def test_multi_hot2(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
+
+ def test_deep_copy(self):
+ a = fc.categorical_column_with_hash_bucket('a', 4)
+ column = fc.indicator_column(a)
+ column_copy = copy.deepcopy(column)
+ self.assertEqual(column_copy.categorical_column.name, 'a')
+ self.assertEqual(column_copy.name, 'a_indicator')
+ self.assertEqual(column_copy.variable_shape, [1, 4])
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_indicator]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ features = {
+ 'aaa': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }
+ indicator_tensor = _transform_features(features, [a_indicator],
+ None)[a_indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
+
+ def test_transform_with_weighted_column(self):
+ # Github issue 12557
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'a']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
+ }
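+ # Each weight is scattered to its id's slot: 'a' (id 0) -> 6., 'b' (id 1) -> 4., 'c' (id 2) -> 2.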
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+
+ def test_transform_with_missing_value_in_weighted_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
+ }
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())
+
+ def test_transform_with_missing_value_in_categorical_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ indicator = fc.indicator_column(ids)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ }
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
+
+ def test_linear_model(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = fc.linear_model(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = get_keras_linear_model_predictions(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
+ def test_input_layer(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ net = fc.input_layer(features, [animal])
+ with _initialized_session():
+ self.assertAllClose([[0., 1., 1., 0.]], net.eval())
+
+
+class _TestStateManager(StateManager):
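+ """A minimal StateManager for tests: creates variables on first request and caches them per feature column."""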
+
+ def __init__(self, trainable=True):
+ # Dict of feature_column to a dict of variables.
+ self._all_variables = {}
+ self._trainable = trainable
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ if feature_column not in self._all_variables:
+ self._all_variables[feature_column] = {}
+ var_dict = self._all_variables[feature_column]
+ if name in var_dict:
+ return var_dict[name]
+ else:
+ var = variable_scope.get_variable(
+ name=name,
+ shape=shape,
+ initializer=initializer,
+ trainable=self._trainable)
+ var_dict[name] = var
+ return var
+
+
+class EmbeddingColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension)
+ self.assertIs(categorical_column, embedding_column.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('mean', embedding_column.combiner)
+ self.assertIsNone(embedding_column.ckpt_to_load_from)
+ self.assertIsNone(embedding_column.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column.max_norm)
+ self.assertTrue(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ combiner='my_combiner', initializer=lambda: 'my_initializer',
+ ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ self.assertIs(categorical_column, embedding_column.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('my_combiner', embedding_column.combiner)
+ self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column.max_norm)
+ self.assertFalse(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_deep_copy(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ original = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ combiner='my_combiner', initializer=lambda: 'my_initializer',
+ ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ for embedding_column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', embedding_column.categorical_column.name)
+ self.assertEqual(3, embedding_column.categorical_column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.categorical_column.parse_example_spec)
+
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('my_combiner', embedding_column.combiner)
+ self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column.max_norm)
+ self.assertFalse(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_invalid_initializer(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+ fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_embedded = fc.embedding_column(a, dimension=2)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_embedded]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform_feature(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ a_embedded = fc.embedding_column(a, dimension=2)
+ features = {
+ 'aaa': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ }
+ outputs = _transform_features(features, [a, a_embedded], None)
+ output_a = outputs[a]
+ output_embedded = outputs[a_embedded]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self, output_a.eval(), output_embedded.eval())
+
+ def test_get_dense_tensor(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def test_get_dense_tensor_3d(self):
+ # Inputs.
+ vocabulary_size = 4
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2], []]
+ # example 1, ids [[], [0, 1]]
+ # example 2, ids [[], []]
+ # example 3, ids [[1], [2]]
+ indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)),
+ values=(2, 0, 1, 1, 2),
+ dense_shape=(4, 2, 5))
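+ # dense_shape (4, 2, 5): 4 examples, each with 2 inner rows of up to 5 ids;
+ # the 'mean' combiner is applied within each inner row.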
+
+ # Embedding variable.
+ embedding_dimension = 3
+ embedding_values = (
+ (1., 2., 4.), # id 0
+ (3., 5., 1.), # id 1
+ (7., 11., 2.), # id 2
+ (2., 7., 12.) # id 3
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]]
+ ((7., 11., 2.), (0., 0., 0.)),
+ # example 1, ids [[], [0, 1]], embedding
+ # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]]
+ ((0., 0., 0.), (2., 3.5, 2.5)),
+ # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]]
+ ((0., 0., 0.), (0., 0., 0.)),
+ # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]]
+ ((3., 5., 1.), (7., 11., 2.)),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def DISABLED_test_get_dense_tensor_weight_collections(self):
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_column = fc.embedding_column(categorical_column, dimension=2)
+
+ # Provide sparse input and get dense result.
+ embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }),
+ weight_collections=('my_vars',))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ my_vars = ops.get_collection('my_vars')
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in my_vars]))
+
+ def test_get_dense_tensor_placeholder_inputs(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int64)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa':
+ sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ }), state_manager)
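+ # The lookup above is built against placeholders with unknown shapes; the
+ # concrete sparse components are fed via feed_dict at evaluation time below.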
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval(
+ feed_dict={
+ input_indices: sparse_input.indices,
+ input_values: sparse_input.values,
+ input_shape: sparse_input.dense_shape,
+ }))
+
+ def test_get_dense_tensor_restore_from_ckpt(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable. The checkpoint file contains _embedding_values.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ ckpt_path = test.test_src_dir_path(
+ 'python/feature_column/testdata/embedding.ckpt')
+ ckpt_tensor = 'my_embedding'
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ ckpt_to_load_from=ckpt_path,
+ tensor_name_in_ckpt=ckpt_tensor)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def test_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v for v in ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars[
+ 'linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+ def test_input_layer(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, trainable_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+ def test_input_layer_not_trainable(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=False)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+
+class _TestSharedEmbeddingStateManager(StateManager):
+ """Manages the state for shared embedding columns.
+
+ This can handle multiple groups of shared embedding columns.
+ """
+
+ def __init__(self, trainable=True):
+ # Dict of shared_embedding_collection_name to a dict of variables.
+ self._all_variables = {}
+ self._trainable = trainable
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ if not isinstance(feature_column, fc.SharedEmbeddingColumn):
+ raise ValueError(
+ 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. '
+ 'Given type: {} '.format(type(feature_column)))
+
+ collection_name = feature_column.shared_collection_name
+ if collection_name not in self._all_variables:
+ self._all_variables[collection_name] = {}
+ var_dict = self._all_variables[collection_name]
+ if name in var_dict:
+ return var_dict[name]
+ else:
+ var = variable_scope.get_variable(
+ name=name,
+ shape=shape,
+ initializer=initializer,
+ trainable=self._trainable)
+ var_dict[name] = var
+ return var
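+ # Note: columns produced by a single fc.shared_embedding_columns call carry
+ # the same shared_collection_name, so this manager hands each of them the
+ # same variable; that is what makes the embedding weights shared.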
+
+
+class SharedEmbeddingColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
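+ # Note: the default shared collection name joins the sorted column names,
+ # hence 'aaa_bbb_shared_embedding' even though 'bbb' is passed first.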
+ self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
+ self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual(embedding_dimension, embedding_column_b.dimension)
+ self.assertEqual('mean', embedding_column_a.combiner)
+ self.assertEqual('mean', embedding_column_b.combiner)
+ self.assertIsNone(embedding_column_a.ckpt_to_load_from)
+ self.assertIsNone(embedding_column_b.ckpt_to_load_from)
+ self.assertEqual('aaa_bbb_shared_embedding',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('aaa_bbb_shared_embedding',
+ embedding_column_b.shared_collection_name)
+ self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column_a.max_norm)
+ self.assertIsNone(embedding_column_b.max_norm)
+ self.assertTrue(embedding_column_a.trainable)
+ self.assertTrue(embedding_column_b.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
+ self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
+ self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+ self.assertEqual({
+ 'bbb': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_b.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ combiner='my_combiner',
+ initializer=lambda: 'my_initializer',
+ shared_embedding_collection_name='shared_embedding_collection_name',
+ ckpt_to_load_from='my_ckpt',
+ tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42.,
+ trainable=False)
+ self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
+ self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual(embedding_dimension, embedding_column_b.dimension)
+ self.assertEqual('my_combiner', embedding_column_a.combiner)
+ self.assertEqual('my_combiner', embedding_column_b.combiner)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_b.shared_collection_name)
+ self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
+ self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
+ self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column_a.max_norm)
+ self.assertEqual(42., embedding_column_b.max_norm)
+ self.assertFalse(embedding_column_a.trainable)
+ self.assertFalse(embedding_column_b.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
+ self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
+ self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+ self.assertEqual({
+ 'bbb': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_b.parse_example_spec)
+
+ def test_deep_copy(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ original_a, _ = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ combiner='my_combiner',
+ initializer=lambda: 'my_initializer',
+ shared_embedding_collection_name='shared_embedding_collection_name',
+ ckpt_to_load_from='my_ckpt',
+ tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ for embedding_column_a in (original_a, copy.deepcopy(original_a)):
+ self.assertEqual('aaa', embedding_column_a.categorical_column.name)
+ self.assertEqual(3, embedding_column_a.categorical_column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.categorical_column.parse_example_spec)
+
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual('my_combiner', embedding_column_a.combiner)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column_a.max_norm)
+ self.assertFalse(embedding_column_a.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual((embedding_dimension,),
+ embedding_column_a.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+
+ def test_invalid_initializer(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+ fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2,
+ initializer='not_fn')
+
+ def test_incompatible_column_type(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ categorical_column_c = fc.categorical_column_with_hash_bucket(
+ key='ccc', hash_bucket_size=3)
+ with self.assertRaisesRegexp(
+ ValueError, 'all categorical_columns must have the same type.*'
+ 'IdentityCategoricalColumn.*HashedCategoricalColumn'):
+ fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b, categorical_column_c],
+ dimension=2)
+
+ def test_weighted_categorical_column_ok(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ weighted_categorical_column_a = fc.weighted_categorical_column(
+ categorical_column_a, weight_feature_key='aaa_weights')
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ weighted_categorical_column_b = fc.weighted_categorical_column(
+ categorical_column_b, weight_feature_key='bbb_weights')
+ fc.shared_embedding_columns(
+ [weighted_categorical_column_a, categorical_column_b], dimension=2)
+ fc.shared_embedding_columns(
+ [categorical_column_a, weighted_categorical_column_b], dimension=2)
+ fc.shared_embedding_columns(
+ [weighted_categorical_column_a, weighted_categorical_column_b],
+ dimension=2)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ b = fc.categorical_column_with_vocabulary_list(
+ key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_embedded, b_embedded = fc.shared_embedding_columns(
+ [a, b], dimension=2)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ 'bbb':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'stringer', b'marlo'])),
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_embedded, b_embedded]))
+ self.assertIn('aaa', features)
+ self.assertIn('bbb', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'stringer', b'marlo'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['bbb'].eval())
+
+ def test_transform_feature(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
+ a_embedded, b_embedded = fc.shared_embedding_columns(
+ [a, b], dimension=2)
+ features = {
+ 'aaa': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ outputs = _transform_features(features, [a, a_embedded, b, b_embedded],
+ None)
+ output_a = outputs[a]
+ output_a_embedded = outputs[a_embedded]
+ output_b = outputs[b]
+ output_b_embedded = outputs[b_embedded]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self, output_a.eval(), output_a_embedded.eval())
+ _assert_sparse_tensor_value(
+ self, output_b.eval(), output_b_embedded.eval())
+
+ def test_get_dense_tensor(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ input_features = {
+ 'aaa': input_a,
+ 'bbb': input_b
+ }
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups_a = (
+ # example 0:
+ (7., 11.), # ids [2], embedding = [7, 11]
+ # example 1:
+ (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ )
+ expected_lookups_b = (
+ # example 0:
+ (1., 2.), # ids [0], embedding = [1, 2]
+ # example 1:
+ (0., 0.), # ids [], embedding = [0, 0]
+ )
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension, initializer=_initializer)
+ state_manager = _TestSharedEmbeddingStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+ embedding_lookup_b = embedding_column_b.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ embedding_var = global_vars[0]
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, embedding_var.eval())
+ self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
+ self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+
+ def DISABLED_test_get_dense_tensor_weight_collections(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ input_features = {'aaa': input_a, 'bbb': input_b}
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ fc.input_layer(
+ input_features, [embedding_column_a, embedding_column_b],
+ weight_collections=('my_vars',))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in global_vars))
+ my_vars = ops.get_collection('my_vars')
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in my_vars))
+
+ def test_get_dense_tensor_placeholder_inputs(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ # Specify the shape, because dense inputs must have a statically known rank.
+ input_a_placeholder = array_ops.placeholder(
+ dtype=dtypes.int64, shape=[None, 3])
+ input_b_placeholder = array_ops.placeholder(
+ dtype=dtypes.int64, shape=[None, 3])
+ input_features = {
+ 'aaa': input_a_placeholder,
+ 'bbb': input_b_placeholder,
+ }
+ feed_dict = {
+ input_a_placeholder: input_a,
+ input_b_placeholder: input_b,
+ }
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension, initializer=_initializer)
+ state_manager = _TestSharedEmbeddingStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+ embedding_lookup_b = embedding_column_b.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+
+ with _initialized_session() as sess:
+ sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
+
+ def test_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights are not scoped under the column name: each shared
+ # embedding column gets its own 'weights' variable in a uniquified scope,
+ # while a single 'embedding_weights' variable is shared. This is a rare
+ # use case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v for v in ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+ # example 1, ids [], embedding[1] = [0, 0]
+ # sum(embeddings * linear_weights)
+ # = [3*1 + 5*2, 3*0 + 5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights are not scoped under the column name: each shared
+ # embedding column gets its own 'weights' variable in a uniquified scope,
+ # while a single 'embedding_weights' variable is shared. This is a rare
+ # use case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+ # example 1, ids [], embedding[1] = [0, 0]
+ # sum(embeddings * linear_weights)
+ # = [3*1 + 5*2, 3*0 + 5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
+ def _test_input_layer(self, trainable=True):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 4)),
+ values=(2, 0, 1),
+ dense_shape=(2, 5))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [0]
+ # example 1, ids []
+ indices=((0, 0),),
+ values=(0,),
+ dense_shape=(2, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0:
+ # A ids [2], embedding = [7, 11]
+ # B ids [0], embedding = [1, 2]
+ (7., 11., 1., 2.),
+ # example 1:
+ # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # B ids [], embedding = [0, 0]
+ (2., 3.5, 0., 0.),
+ )
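+ # Note: input_layer sorts columns by name, so 'aaa' embeddings precede
+ # 'bbb' embeddings even though the columns are passed as (b, a) below.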
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=trainable)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer(
+ features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
+ feature_columns=(embedding_column_b, embedding_column_a))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ tuple([v.name for v in global_vars]))
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ if trainable:
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ tuple([v.name for v in trainable_vars]))
+ else:
+ self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
+ shared_embedding_vars = global_vars
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+ def test_input_layer(self):
+ self._test_input_layer()
+
+ def test_input_layer_no_trainable(self):
+ self._test_input_layer(trainable=False)
+
+
+class WeightedCategoricalColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ self.assertEqual('ids_weighted_by_values', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'ids': parsing_ops.VarLenFeature(dtypes.int64),
+ 'values': parsing_ops.VarLenFeature(dtypes.float32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('ids_weighted_by_values', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'ids': parsing_ops.VarLenFeature(dtypes.int64),
+ 'values': parsing_ops.VarLenFeature(dtypes.float32)
+ }, column.parse_example_spec)
+
+ def test_invalid_dtype_none(self):
+ with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values',
+ dtype=None)
+
+ def test_invalid_dtype_string(self):
+ with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values',
+ dtype=dtypes.string)
+
+ def test_invalid_input_dtype(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ strings = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Bad dtype'):
+ _transform_features({'ids': strings, 'values': strings}, (column,), None)
+
+ def test_column_name_collision(self):
+ with self.assertRaisesRegexp(ValueError, r'Parse config.*already exists'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3),
+ weight_feature_key='aaa').parse_example_spec()
+
+ def test_missing_weights(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(
+ ValueError, 'values is not in features dictionary'):
+ _transform_features({'ids': inputs}, (column,), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_weighted = fc.weighted_categorical_column(a, weight_feature_key='weights')
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ 'weights':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[1., 10.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_weighted]))
+ self.assertIn('aaa', features)
+ self.assertIn('weights', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([1., 10.], dtype=np.float32),
+ dense_shape=[1, 2]),
+ features['weights'].eval())
+
+ def test_transform_features(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ weights = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0.5, 1.0, 0.1),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': inputs,
+ 'values': weights,
+ }, (column,), None)[column]
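+ # The transform yields an (id_tensor, weight_tensor) pair: ids are cast to
+ # int64 and weights to the column's float dtype (float32 by default).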
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(inputs.values, dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=weights.indices,
+ values=np.array(weights.values, dtype=np.float32),
+ dense_shape=weights.dense_shape),
+ weight_tensor.eval())
+
+ def test_transform_features_dense_input(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ weights = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0.5, 1.0, 0.1),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': ((0, -1), (1, 0)),
+ 'values': weights,
+ }, (column,), None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=weights.indices,
+ values=np.array(weights.values, dtype=np.float32),
+ dense_shape=weights.dense_shape),
+ weight_tensor.eval())
+
+ def test_transform_features_dense_weights(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 1, 0),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': inputs,
+ 'values': ((.5, 0.), (1., .1)),
+ }, (column,), None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(inputs.values, dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((.5, 1., .1), dtype=np.float32),
+ dense_shape=(2, 2)),
+ weight_tensor.eval())
+
+ def test_keras_linear_model(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_keras_linear_model_mismatched_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
+ get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_keras_linear_model_mismatched_dense_values(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
+ # Disable the constant folding optimizer here, since with it enabled the
+ # error message differs between CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_keras_linear_model_mismatched_dense_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_linear_model(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_linear_model_mismatched_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, r'Dimensions.*are not compatible'):
+ fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
+ # Disable the constant folding optimizer here, since with it enabled the
+ # error message differs between CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_linear_model_mismatched_dense_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ # TODO(ptucker): Add test with embedding of weighted categorical.
+
+if __name__ == '__main__':
+ test.main()
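The hand-computed expectations in the tests above all follow the same rule: each example's prediction is the sum of weight_var[id] * weight over that example's (id, weight) pairs. A minimal NumPy sketch of the arithmetic (illustrative only, not part of the patch):

    import numpy as np

    weight_var = np.array([1., 2., 3.])      # one weight per bucket (num_buckets=3)
    examples = [[(0, .5)],                   # example 0: id 0 with weight .5
                [(2, 1.), (1, .1)]]          # example 1: ids 2 and 1, weights 1. and .1
    predictions = [sum(weight_var[i] * w for i, w in ex) for ex in examples]
    assert np.allclose(predictions, [.5, 3.2])  # matches the values asserted above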
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 3c5aebbce8..40788e24c4 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -28,6 +28,18 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+def has_fully_defined_shape(tensor):
+ """Returns true if tensor has a fully defined shape."""
+ return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
+
+
+def rank(tensor):
+ """Return a rank if it is a tensor, else return None."""
+ if isinstance(tensor, ops.Tensor):
+ return tensor._rank() # pylint: disable=protected-access
+ return None
+
+
def scalar_shape(unused_op):
"""Shape function for ops that output a scalar value."""
return [tensor_shape.scalar()]
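A rough sketch of how the two new helpers behave, assuming the TF 1.x-era graph API (illustrative, not from the patch):

    # Assumes TF 1.x graph-mode API (tf.placeholder, default graph).
    import tensorflow as tf
    from tensorflow.python.framework import common_shapes

    defined = tf.constant([[1., 2.], [3., 4.]])        # static shape (2, 2)
    undefined = tf.placeholder(tf.float32, [None, 2])  # first dimension unknown

    common_shapes.has_fully_defined_shape(defined)     # True
    common_shapes.has_fully_defined_shape(undefined)   # False
    common_shapes.rank(defined)                        # 2
    common_shapes.rank('not a tensor')                 # None (not an ops.Tensor)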
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 15e41ba91f..1707f929b8 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -537,19 +537,25 @@ class FunctionTest(test.TestCase):
def testResourceVarAsImplicitInput(self):
g = ops.Graph()
with g.as_default(), ops.device("cpu:0"):
+ expected_type = dtypes.float32
+ expected_shape = tensor_shape.TensorShape((4, 4))
v = variable_scope.get_variable(
- "var", (4, 4), dtypes.float32, use_resource=True)
+ "var", expected_shape, expected_type, use_resource=True)
@function.Defun()
def Foo():
- return array_ops.identity(v)
+ captured = array_ops.identity(v)
+ self.assertEqual(expected_type, captured.dtype)
+ self.assertEqual(expected_shape, captured.shape)
+ return captured, array_ops.shape(captured)
- y = v.value()
- z = Foo()
+ expected_val = v.value()
+ actual_val, actual_shape = Foo()
with self.test_session(graph=g):
v.initializer.run()
- self.assertAllEqual(y.eval(), z.eval())
+ self.assertAllEqual(expected_val.eval(), actual_val.eval())
+ self.assertAllEqual(expected_shape, actual_shape.eval())
def testDefineErrors(self):
with ops.Graph().as_default():
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index cf0b1e36fb..c4f58f0847 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import linecache
import os
import re
import sys
@@ -48,7 +47,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
+from tensorflow.python.util import tf_stack
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
@@ -706,7 +707,7 @@ class _EagerTensorBase(Tensor):
"""
if self.dtype == dtypes.resource:
raise ValueError("Resource handles are not convertible to numpy.")
- return self.cpu()._numpy() # pylint: disable=protected-access
+ return self._cpu_nograd()._numpy() # pylint: disable=protected-access
# __int__ and __float__ may copy the tensor to CPU and
# only work for scalars; values are cast as per numpy.
@@ -780,8 +781,8 @@ class _EagerTensorBase(Tensor):
def _override_operator(name, func):
setattr(_EagerTensorBase, name, func)
- def _copy(self, ctx=None, device_name=None):
- """Copies tensor to dest device."""
+ def _copy_nograd(self, ctx=None, device_name=None):
+ """Copies tensor to dest device, but doesn't record the operation."""
# pylint: disable=protected-access
# Creates a new tensor on the dest device.
if ctx is None:
@@ -793,7 +794,11 @@ class _EagerTensorBase(Tensor):
new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
except core._NotOkStatusException as e:
six.raise_from(core._status_to_exception(e.code, e.message), None)
+ return new_tensor
+ def _copy(self, ctx=None, device_name=None):
+ """Copies tensor to dest device."""
+ new_tensor = self._copy_nograd(ctx, device_name)
# Record the copy on tape and define backprop copy as well.
if context.executing_eagerly():
self_device = self.device
@@ -824,6 +829,16 @@ class _EagerTensorBase(Tensor):
"""Returns the number of Tensor dimensions."""
return self.shape.ndims
+ def _cpu_nograd(self):
+ """A copy of this Tensor with contents backed by host memory.
+
+ The copy cannot be differentiated through.
+
+ Returns:
+ A CPU-memory backed Tensor object with the same contents as this Tensor.
+ """
+ return self._copy_nograd(context.context(), "CPU:0")
+
def cpu(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.context(), "CPU:0")
@@ -1699,7 +1714,7 @@ class Operation(object):
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._original_op = original_op
- self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
+ self._traceback = tf_stack.extract_stack()
self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access
# Initialize self._c_op.
@@ -2140,7 +2155,7 @@ class Operation(object):
@property
def traceback(self):
"""Returns the call stack from when this operation was constructed."""
- return self._graph._convert_stack(self._traceback) # pylint: disable=protected-access
+ return tf_stack.convert_stack(self._traceback)
@property
def traceback_with_start_lines(self):
@@ -2149,9 +2164,8 @@ class Operation(object):
Returns:
A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
"""
- return self._graph._convert_stack( # pylint: disable=protected-access
- self._traceback,
- include_func_start_lineno=True)
+ return tf_stack.convert_stack(self._traceback,
+ include_func_start_lineno=True)
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
@@ -2603,7 +2617,6 @@ def _name_from_scope_name(name):
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
-
@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2712,7 +2725,7 @@ class Graph(object):
self._building_function = False
# Stack of colocate_with ops. After switch_to_thread_local(),
# self._thread_local._colocation_stack is used instead.
- self._graph_colocation_stack = []
+ self._graph_colocation_stack = traceable_stack.TraceableStack()
# Set of tensors that are dangerous to feed!
self._unfeedable_tensors = set()
# Set of operations that are dangerous to fetch!
@@ -2752,36 +2765,6 @@ class Graph(object):
"""Temporary hack; can be overridden to force C API usage."""
return _USE_C_API
- def _convert_stack(self, stack, include_func_start_lineno=False):
- """Converts a stack extracted using _extract_stack() to a traceback stack.
-
- Args:
- stack: A list of n 5-tuples,
- (filename, lineno, name, frame_globals, func_start_lineno).
- include_func_start_lineno: True if function start line number should be
- included as the 5th entry in return tuples.
-
- Returns:
- A list of n 4-tuples or 5-tuples
- (filename, lineno, name, code, [optional: func_start_lineno]), where the
- code tuple element is calculated from the corresponding elements of the
- input tuple.
- """
- ret = []
- for (filename, lineno, name, frame_globals, func_start_lineno,
- unused_frame_info) in stack:
- linecache.checkcache(filename)
- line = linecache.getline(filename, lineno, frame_globals)
- if line:
- line = line.strip()
- else:
- line = None
- if include_func_start_lineno:
- ret.append((filename, lineno, name, line, func_start_lineno))
- else:
- ret.append((filename, lineno, name, line))
- return ret
-
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@tf_contextlib.contextmanager
@@ -2807,46 +2790,6 @@ class Graph(object):
def _variable_creator_stack(self, variable_creator_stack):
self._thread_local._variable_creator_stack = variable_creator_stack
- def _extract_stack(self):
- """A lightweight, extensible re-implementation of traceback.extract_stack.
-
- NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
- each stack frame using linecache, which results in an abundance of stat()
- calls. This implementation does not retrieve the code, and any consumer
- should apply _convert_stack to the result to obtain a traceback that can
- be formatted etc. using traceback methods.
-
- Derived classes can implement _extract_frame_info() to add extra information
- to the traceback.
-
- Returns:
- A list of 6-tuples
- (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
- corresponding to the call stack of the current thread.
- """
- try:
- raise ZeroDivisionError
- except ZeroDivisionError:
- f = sys.exc_info()[2].tb_frame.f_back
- ret = []
- while f is not None:
- lineno = f.f_lineno
- co = f.f_code
- filename = co.co_filename
- name = co.co_name
- frame_globals = f.f_globals
- func_start_lineno = co.co_firstlineno
- frame_info = self._extract_frame_info(f)
- ret.append((filename, lineno, name, frame_globals, func_start_lineno,
- frame_info))
- f = f.f_back
- ret.reverse()
- return ret
-
- def _extract_frame_info(self, frame): # pylint: disable=unused-argument
- """Extracts custom information from a frame in an op traceback."""
- return None
-
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -3287,7 +3230,7 @@ class Graph(object):
if self._colocation_stack:
all_colocation_groups = []
- for colocation_op in self._colocation_stack:
+ for colocation_op in self._colocation_stack.peek_objs():
all_colocation_groups.extend(colocation_op.colocation_groups())
if colocation_op.device:
# Make this device match the device of the colocated op, to provide
@@ -4060,10 +4003,10 @@ class Graph(object):
if ignore_existing:
current_stack = self._colocation_stack
- self._colocation_stack = []
+ self._colocation_stack = traceable_stack.TraceableStack()
if op is not None:
- self._colocation_stack.append(op)
+ self._colocation_stack.push_obj(op, name=op.name, offset=1)
try:
yield
@@ -4071,7 +4014,7 @@ class Graph(object):
# Restore device function stack
self._device_function_stack = device_fn_tmp
if op is not None:
- self._colocation_stack.pop()
+ self._colocation_stack.pop_obj()
# Reset the colocation stack if requested.
if ignore_existing:
@@ -4698,11 +4641,15 @@ class Graph(object):
@property
def _colocation_stack(self):
+ """Return thread-local copy of colocation stack."""
if self._stack_state_is_thread_local:
# This may be called from a thread where colocation_stack doesn't yet
# exist.
if not hasattr(self._thread_local, "_colocation_stack"):
- self._thread_local._colocation_stack = self._graph_colocation_stack[:]
+ stack_copy_for_this_thread = self._graph_colocation_stack.copy()
+ # pylint: disable=protected-access
+ self._thread_local._colocation_stack = stack_copy_for_this_thread
+ # pylint: enable=protected-access
return self._thread_local._colocation_stack
else:
return self._graph_colocation_stack
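The net effect of the ops.py changes is that stack extraction and conversion move into the shared tf_stack module and the colocation stack gains per-entry source locations; the consumer-facing surface is unchanged. A small sketch of that unchanged behavior (assuming TF 1.x graph mode):

    # Assumes TF 1.x graph-mode API.
    import tensorflow as tf

    with tf.Graph().as_default():
      op = tf.constant(1.0).op
      # Operation.traceback still yields (filename, lineno, name, code) tuples,
      # now produced by tf_stack.extract_stack/convert_stack.
      for filename, lineno, name, code in op.traceback:
        print('%s:%d in %s: %s' % (filename, lineno, name, code))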
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index d6edc13643..395cf43b3f 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -50,13 +50,13 @@ class TensorUtilTest(test.TestCase):
def testFloatN(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -68,13 +68,13 @@ class TensorUtilTest(test.TestCase):
def testFloatTyped(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -86,13 +86,13 @@ class TensorUtilTest(test.TestCase):
def testFloatTypeCoerce(self):
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -105,13 +105,13 @@ class TensorUtilTest(test.TestCase):
arr = np.asarray([10, 20, 30], dtype="int")
t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -123,13 +123,13 @@ class TensorUtilTest(test.TestCase):
def testFloatSizes(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -141,13 +141,13 @@ class TensorUtilTest(test.TestCase):
def testFloatSizes2(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -169,13 +169,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(
np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
@@ -206,13 +206,13 @@ class TensorUtilTest(test.TestCase):
self.assertEquals(np.float32, a.dtype)
self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -299,16 +299,16 @@ class TensorUtilTest(test.TestCase):
def testIntNDefaultType(self):
t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
- tensor_content: "\000\000\000\\n\000\000\000\024\000\000\000\036\000\000\000("
+ tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000("
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
- tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
+ tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int32, a.dtype)
@@ -380,16 +380,16 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(
[10, 20, 30], shape=[1, 3], dtype=dtypes.int64)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
- tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
+ tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
- tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int64, a.dtype)
@@ -398,16 +398,16 @@ class TensorUtilTest(test.TestCase):
def testLongNpArray(self):
t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
- tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
+ tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
- tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int64, a.dtype)
@@ -419,13 +419,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
@@ -435,7 +435,7 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)
t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8)
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT8
tensor_shape { dim { size: 3 } }
tensor_content: "\025\026\027"
@@ -445,7 +445,7 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8)
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT8
tensor_shape { dim { size: 3 } }
tensor_content: "\025\026\027"
@@ -456,13 +456,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\026\000\027\000"
@@ -473,13 +473,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\026\000\027\000"
diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py
new file mode 100644
index 0000000000..1b7c6bd7c5
--- /dev/null
+++ b/tensorflow/python/framework/traceable_stack.py
@@ -0,0 +1,135 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple stack that associates filename and line numbers with each object."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util import tf_stack
+
+
+class TraceableObject(object):
+ """Wrap an object together with its the code definition location."""
+
+ # Return codes for the set_filename_and_line_from_caller() method.
+ SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2)
+
+ def __init__(self, obj, name=None, filename=None, lineno=None):
+ self.obj = obj
+ self.name = name
+ self.filename = filename
+ self.lineno = lineno
+
+ def set_filename_and_line_from_caller(self, offset=0):
+ """Set filename and line using the caller's stack frame.
+
+ If the requested stack information is not available, a heuristic may
+ be applied and self.HEURISTIC_USED will be returned. If the heuristic
+ fails then no change will be made to the filename and lineno members
+ (None by default) and self.FAILURE will be returned.
+
+ Args:
+ offset: Integer. If 0, the caller's stack frame is used. If 1,
+ the caller's caller's stack frame is used. Larger values are
+ permissible, but if out-of-range (larger than the number of stack
+ frames available) the outermost stack frame will be used.
+
+ Returns:
+ TraceableObject.SUCCESS if appropriate stack information was found,
+ TraceableObject.HEURISTIC_USED if the offset was larger than the stack,
+ and TraceableObject.FAILURE if the stack was empty.
+ """
+ # Offset is defined in "Args" as relative to the caller. We are one frame
+ # beyond the caller.
+ local_offset = offset + 1
+
+ frame_records = tf_stack.extract_stack()
+ if not frame_records:
+ return self.FAILURE
+ if len(frame_records) >= local_offset:
+ # Negative indexing starts at -1 rather than 0, hence the extra +1 here.
+ negative_offset = -(local_offset + 1)
+ self.filename, self.lineno = frame_records[negative_offset][:2]
+ return self.SUCCESS
+ else:
+ # If the offset is too large then we use the largest offset possible,
+ # meaning we use the outermost stack frame at index 0.
+ self.filename, self.lineno = frame_records[0][:2]
+ return self.HEURISTIC_USED
+
+ def copy_metadata(self):
+ """Return a TraceableObject like this one, but without the object."""
+ return self.__class__(None, name=self.name, filename=self.filename,
+ lineno=self.lineno)
+
+
+class TraceableStack(object):
+ """A stack of TraceableObjects."""
+
+ def __init__(self, existing_stack=None):
+ """Constructor.
+
+ Args:
+ existing_stack: [TraceableObject, ...] If provided, this object will
+ set its new stack to a SHALLOW COPY of existing_stack.
+ """
+ self._stack = existing_stack[:] if existing_stack else []
+
+ def push_obj(self, obj, name=None, offset=0):
+ """Add object to the stack and record its filename and line information.
+
+ Args:
+ obj: An object to store on the stack.
+ name: A name for the object; used when reporting the object's metadata.
+ offset: Integer. If 0, the caller's stack frame is used. If 1,
+ the caller's caller's stack frame is used.
+
+ Returns:
+ TraceableObject.SUCCESS if appropriate stack information was found,
+ TraceableObject.HEURISTIC_USED if the stack was smaller than expected,
+ and TraceableObject.FAILURE if the stack was empty.
+ """
+ traceable_obj = TraceableObject(obj, name=name)
+ self._stack.append(traceable_obj)
+ # Offset is defined in "Args" as relative to the caller. We are 1 frame
+ # beyond the caller and need to compensate.
+ return traceable_obj.set_filename_and_line_from_caller(offset + 1)
+
+ def pop_obj(self):
+ """Remove last-inserted object and return it, without filename/line info."""
+ return self._stack.pop().obj
+
+ def peek_objs(self):
+ """Return list of stored objects ordered newest to oldest."""
+ return [t_obj.obj for t_obj in reversed(self._stack)]
+
+ def peek_traceable_objs(self):
+ """Return list of stored TraceableObjects ordered newest to oldest."""
+ return list(reversed(self._stack))
+
+ def __len__(self):
+ """Return number of items on the stack, and used for truth-value testing."""
+ return len(self._stack)
+
+ def copy(self):
+ """Return a copy of self referencing the same objects but in a new list.
+
+ This method is implemented to support thread-local stacks.
+
+ Returns:
+ TraceableStack with a new list that holds existing objects.
+ """
+ return TraceableStack(self._stack)
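A minimal usage sketch of the new stack (illustrative, not part of the patch): each push records where the call happened, so consumers such as colocation error reporting can point back at user code.

    from tensorflow.python.framework import traceable_stack

    stack = traceable_stack.TraceableStack()
    stack.push_obj('loss_scope', name='scope')   # records this line's file/lineno
    top = stack.peek_traceable_objs()[0]         # newest entry first
    print(top.obj, top.name, top.filename, top.lineno)
    print(stack.peek_objs())                     # ['loss_scope']
    assert stack.pop_obj() == 'loss_scope'       # returns the bare object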
diff --git a/tensorflow/python/framework/traceable_stack_test.py b/tensorflow/python/framework/traceable_stack_test.py
new file mode 100644
index 0000000000..3e7876f631
--- /dev/null
+++ b/tensorflow/python/framework/traceable_stack_test.py
@@ -0,0 +1,133 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.traceable_stack."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import traceable_stack
+from tensorflow.python.platform import googletest
+from tensorflow.python.util import tf_inspect as inspect
+
+_LOCAL_OBJECT = lambda x: x
+_THIS_FILENAME = inspect.getsourcefile(_LOCAL_OBJECT)
+
+
+class TraceableObjectTest(test_util.TensorFlowTestCase):
+
+ def testSetFilenameAndLineFromCallerUsesCallersStack(self):
+ t_obj = traceable_stack.TraceableObject(17)
+
+ # Do not separate placeholder from the set_filename_and_line_from_caller()
+ # call one line below it as it is used to calculate the latter's line
+ # number.
+ placeholder = lambda x: x
+ result = t_obj.set_filename_and_line_from_caller()
+
+ expected_lineno = inspect.getsourcelines(placeholder)[1] + 1
+ self.assertEqual(expected_lineno, t_obj.lineno)
+ self.assertEqual(_THIS_FILENAME, t_obj.filename)
+ self.assertEqual(t_obj.SUCCESS, result)
+
+ def testSetFilenameAndLineFromCallerRespectsOffset(self):
+
+ def call_set_filename_and_line_from_caller(t_obj):
+ # We expect to retrieve the line number from _our_ caller.
+ return t_obj.set_filename_and_line_from_caller(offset=1)
+
+ t_obj = traceable_stack.TraceableObject(None)
+ # Do not separate placeholder from the
+ # call_set_filename_and_line_from_caller() call one line below it as it is
+ # used to calculate the latter's line number.
+ placeholder = lambda x: x
+ result = call_set_filename_and_line_from_caller(t_obj)
+
+ expected_lineno = inspect.getsourcelines(placeholder)[1] + 1
+ self.assertEqual(expected_lineno, t_obj.lineno)
+ self.assertEqual(t_obj.SUCCESS, result)
+
+ def testSetFilenameAndLineFromCallerHandlesRidiculousOffset(self):
+ t_obj = traceable_stack.TraceableObject('The quick brown fox.')
+ # This line shouldn't die.
+ result = t_obj.set_filename_and_line_from_caller(offset=300)
+
+ # We expect a heuristic to be used because we are not currently 300 frames
+ # down on the stack. The filename and lineno of the outermost frame are not
+ # predictable -- in some environments the filename is this test file, but in
+ # other environments it is not (e.g. due to a test runner calling this
+ # file). Therefore we only test that the called function knows it applied a
+ # heuristic for the ridiculous stack offset.
+ self.assertEqual(t_obj.HEURISTIC_USED, result)
+
+
+class TraceableStackTest(test_util.TensorFlowTestCase):
+
+ def testPushPeekPopObj(self):
+ t_stack = traceable_stack.TraceableStack()
+ t_stack.push_obj(42.0)
+ t_stack.push_obj('hope')
+
+ expected_lifo_peek = ['hope', 42.0]
+ self.assertEqual(expected_lifo_peek, t_stack.peek_objs())
+
+ self.assertEqual('hope', t_stack.pop_obj())
+ self.assertEqual(42.0, t_stack.pop_obj())
+
+ def testPushPopPreserveLifoOrdering(self):
+ t_stack = traceable_stack.TraceableStack()
+ t_stack.push_obj(0)
+ t_stack.push_obj(1)
+ t_stack.push_obj(2)
+ t_stack.push_obj(3)
+
+ obj_3 = t_stack.pop_obj()
+ obj_2 = t_stack.pop_obj()
+ obj_1 = t_stack.pop_obj()
+ obj_0 = t_stack.pop_obj()
+
+ self.assertEqual(3, obj_3)
+ self.assertEqual(2, obj_2)
+ self.assertEqual(1, obj_1)
+ self.assertEqual(0, obj_0)
+
+ def testPushObjSetsFilenameAndLineInfoForCaller(self):
+ t_stack = traceable_stack.TraceableStack()
+
+ # We expect that the line number recorded for the 1-object will come from
+ # the call to t_stack.push_obj(1). Do not separate the next two lines!
+ placeholder_1 = lambda x: x
+ t_stack.push_obj(1)
+
+ # We expect that the line number recorded for the 2-object will come from
+ # the call to call_push_obj() and _not_ the call to t_stack.push_obj().
+ def call_push_obj(obj):
+ t_stack.push_obj(obj, offset=1)
+
+ # Do not separate the next two lines!
+ placeholder_2 = lambda x: x
+ call_push_obj(2)
+
+ expected_lineno_1 = inspect.getsourcelines(placeholder_1)[1] + 1
+ expected_lineno_2 = inspect.getsourcelines(placeholder_2)[1] + 1
+
+ t_obj_2, t_obj_1 = t_stack.peek_traceable_objs()
+ self.assertEqual(expected_lineno_2, t_obj_2.lineno)
+ self.assertEqual(expected_lineno_1, t_obj_1.lineno)
+
+
+if __name__ == '__main__':
+ googletest.main()
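The offset mechanics exercised above are what lets Graph._colocation_stack push with offset=1: blame shifts one frame up, so the recorded line is the user's colocate_with call rather than the push_obj call inside it. A hypothetical helper as a sketch (illustrative only):

    from tensorflow.python.framework import traceable_stack

    stack = traceable_stack.TraceableStack()

    def push_on_behalf_of_caller(obj):
      stack.push_obj(obj, offset=1)   # record the caller's line, not this one

    push_on_behalf_of_caller('op')    # <- this is the line that gets recorded
    print(stack.peek_traceable_objs()[0].lineno)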
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 8b6b28bc77..4056818a95 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -451,6 +451,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
],
shard_count = 2,
+ tags = ["no_windows_gpu"],
)
py_test(
@@ -720,6 +721,7 @@ py_test(
size = "medium",
srcs = ["preprocessing/image_test.py"],
srcs_version = "PY2AND3",
+ tags = ["nomsan"], # TODO(b/110990716) reenable
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index e56c695a28..7285e03963 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -72,13 +72,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import constraints
-from tensorflow.python.keras import initializers
-from tensorflow.python.keras import regularizers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.layers import Conv2D
@@ -87,10 +83,10 @@ from tensorflow.python.keras.layers import Dropout
from tensorflow.python.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras.layers import Input
+from tensorflow.python.keras.layers import ReLU
from tensorflow.python.keras.layers import Reshape
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.models import Model
-from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
@@ -100,10 +96,6 @@ from tensorflow.python.util.tf_export import tf_export
BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/'
-def relu6(x):
- return K.relu(x, max_value=6)
-
-
@tf_export('keras.applications.mobilenet.preprocess_input')
def preprocess_input(x):
"""Preprocesses a numpy array encoding a batch of images.
@@ -130,12 +122,6 @@ def MobileNet(input_shape=None,
classes=1000):
"""Instantiates the MobileNet architecture.
- To load a MobileNet model via `load_model`, import the custom
- objects `relu6` and pass them to the `custom_objects` parameter.
- E.g.
- model = load_model('mobilenet.h5', custom_objects={
- 'relu6': mobilenet.relu6})
-
Arguments:
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
@@ -412,7 +398,7 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
strides=strides,
name='conv1')(x)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
- return Activation(relu6, name='conv1_relu')(x)
+ return ReLU(6, name='conv1_relu')(x)
def _depthwise_conv_block(inputs,
@@ -479,7 +465,7 @@ def _depthwise_conv_block(inputs,
use_bias=False,
name='conv_dw_%d' % block_id)(x)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
- x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
+ x = ReLU(6, name='conv_dw_%d_relu' % block_id)(x)
x = Conv2D(
pointwise_conv_filters, (1, 1),
@@ -489,4 +475,4 @@ def _depthwise_conv_block(inputs,
name='conv_pw_%d' % block_id)(
x)
x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
- return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
+ return ReLU(6, name='conv_pw_%d_relu' % block_id)(x)
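The built-in ReLU layer with max_value=6 computes the same function the removed relu6 helper did, which is why loading a saved MobileNet no longer needs a custom_objects entry. A rough equivalence check, assuming the era's Keras backend API (illustrative only):

    import numpy as np
    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras.layers import ReLU

    x = K.constant([-1., 3., 9.])
    old = K.relu(x, max_value=6)   # what Activation(relu6) used to compute
    new = ReLU(6)(x)
    assert np.allclose(K.eval(old), K.eval(new))  # both yield [0., 3., 6.]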
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 824513dce0..cb3423598b 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -963,13 +963,14 @@ def zeros(shape, dtype=None, name=None):
[ 0., 0., 0., 0.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.ones')
@@ -996,13 +997,14 @@ def ones(shape, dtype=None, name=None):
[ 1., 1., 1., 1.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.eye')
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 3ae06d7ab8..53d907a2cc 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -733,6 +733,7 @@ class TensorBoard(Callback):
self.model = model
self.sess = K.get_session()
+ # Only make the histogram summary op if it hasn't already been made.
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
for weight in layer.weights:
@@ -787,20 +788,34 @@ class TensorBoard(Callback):
def _fetch_callback(self, summary):
self.writer.add_summary(
- summary, self._epoch + self._current_batch / self._batches_per_epoch)
- self._current_batch += 1
+ summary,
+ self._epoch + self._current_val_batch / self._validation_batches)
+ self._current_val_batch += 1
+
+ def on_train_begin(self, logs=None):
+ """Checks if histogram summaries can be run."""
+
+ if self.histogram_freq:
+ if 'validation_steps' in self.params:
+ self._validation_batches = self.params['validation_steps']
+ elif self.validation_data:
+ self._validation_batches = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ else:
+ raise ValueError('If printing histograms, validation data must be '
+ 'provided.')
+ if self._validation_batches == 0:
+ raise ValueError(
+ 'If printing histograms, validation data must have length > 0.')
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model test_function callbacks, reset batch count."""
- if not self.validation_data and self.histogram_freq:
- raise ValueError('If printing histograms, validation_data must be '
- 'provided, and cannot be a generator.')
+ # Check whether the histogram summary should run for this epoch.
if self.histogram_freq and epoch % self.histogram_freq == 0:
self._epoch = epoch
- self._current_batch = 0
- self._batches_per_epoch = math.ceil(
- self.validation_data[0].shape[0] / self.batch_size)
+ self._current_val_batch = 0
+ # Add the histogram summary op if it should run this epoch.
if self.merged not in self.model.test_function.fetches:
self.model.test_function.fetches.append(self.merged)
self.model.test_function.fetch_callbacks[
@@ -811,7 +826,8 @@ class TensorBoard(Callback):
logs = logs or {}
- if self.histogram_freq and self.histogram_freq > 1:
+ # Remove the histogram summary op after each epoch.
+ if self.histogram_freq:
if self.merged in self.model.test_function.fetches:
self.model.test_function.fetches.remove(self.merged)
if self.merged in self.model.test_function.fetch_callbacks:
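The batch-count bookkeeping moves out of on_epoch_begin into on_train_begin; a condensed sketch of the resolution order it now applies (illustrative, not the library code):

    import math

    def resolve_validation_batches(params, validation_data, batch_size):
      if 'validation_steps' in params:    # generator / Sequence validation
        return params['validation_steps']
      elif validation_data:               # in-memory validation arrays
        return math.ceil(validation_data[0].shape[0] / batch_size)
      raise ValueError('If printing histograms, validation data must be provided.')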
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index d56f2f5bfc..45598cafd3 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -813,21 +813,6 @@ class KerasCallbacksTest(test.TestCase):
for cb in cbs:
cb.on_train_end()
- # fit generator with validation data generator should raise ValueError if
- # histogram_freq > 0
- cbs = callbacks_factory(histogram_freq=1)
- with self.assertRaises(ValueError):
- model.fit_generator(
- data_generator(True),
- len(x_train),
- epochs=2,
- validation_data=data_generator(False),
- validation_steps=1,
- callbacks=cbs)
-
- for cb in cbs:
- cb.on_train_end()
-
# Make sure file writer cache is clear to avoid failures during cleanup.
writer_cache.FileWriterCache.clear()
@@ -976,6 +961,56 @@ class KerasCallbacksTest(test.TestCase):
self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5])
+ def test_Tensorboard_histogram_summaries_with_generator(self):
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
+ model.add(keras.layers.Dense(10, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ cbks = [tsb]
+
+ # fit with validation generator
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=cbks,
+ verbose=0)
+
+ with self.assertRaises(ValueError):
+ # fit with validation generator but no
+ # validation_steps
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ callbacks=cbks,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(tmpdir))
+
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 361778570b..e02792208b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -460,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
- def add_weight(self, name, shape,
+ def add_weight(self,
+ name,
+ shape,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
constraint=None,
partitioner=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
getter=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -482,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
getter: Variable getter argument to be passed to the `Checkpointable` API.
Returns:
@@ -496,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
- ValueError: When giving unsupported dtype and no initializer.
+ ValueError: When giving unsupported dtype and no initializer or when
+ trainable has been set to True with synchronization set as `ON_READ`.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
@@ -505,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
# Initialize variable when no initializer provided
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
@@ -532,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase):
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
@@ -655,8 +685,8 @@ class Layer(checkpointable.CheckpointableBase):
# Handle Keras mask propagation from previous layer to current layer.
previous_mask = None
- if (not hasattr(self, '_compute_previous_mask') or
- self._compute_previous_mask):
+ if build_graph and (not hasattr(self, '_compute_previous_mask') or
+ self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
self._call_fn_args = self._no_dependency(
@@ -693,9 +723,10 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
- if all(hasattr(x, 'get_shape') for x in input_list):
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ if all(hasattr(x, 'shape') for x in input_list):
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
self.build(input_shapes)
+ self.built = True
# Check input assumptions set after layer building, e.g. input shape.
if build_graph or in_deferred_mode:
@@ -711,7 +742,7 @@ class Layer(checkpointable.CheckpointableBase):
# Deferred mode behavior: use `compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
if input_shapes is None:
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
output_shapes = self.compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
@@ -731,8 +762,6 @@ class Layer(checkpointable.CheckpointableBase):
if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
inputs, outputs = self._set_connectivity_metadata_(
inputs, outputs, args, kwargs)
-
- self.built = True
if context.executing_eagerly():
return outputs
@@ -1295,7 +1324,7 @@ class Layer(checkpointable.CheckpointableBase):
', but the layer isn\'t built. '
'You can build it manually via: `' + self.name +
'.build(batch_input_shape)`.')
- weight_shapes = [w.get_shape().as_list() for w in self.weights]
+ weight_shapes = [w.shape.as_list() for w in self.weights]
return int(sum([np.prod(w) for w in weight_shapes]))
@property
@@ -1378,7 +1407,7 @@ class Layer(checkpointable.CheckpointableBase):
if (spec.ndim is not None or
spec.min_ndim is not None or
spec.max_ndim is not None):
- if x.get_shape().ndims is None:
+ if x.shape.ndims is None:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'its rank is undefined, but the layer requires a '
@@ -1386,29 +1415,29 @@ class Layer(checkpointable.CheckpointableBase):
# Check ndim.
if spec.ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim != spec.ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
str(ndim) + '. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
if spec.max_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim > spec.max_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected max_ndim=' + str(spec.max_ndim) +
', found ndim=' + str(ndim))
if spec.min_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim < spec.min_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
': expected min_ndim=' + str(spec.min_ndim) +
', found ndim=' + str(ndim) +
'. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
# Check dtype.
if spec.dtype is not None:
if x.dtype != spec.dtype:
@@ -1418,7 +1447,7 @@ class Layer(checkpointable.CheckpointableBase):
', found dtype=' + str(x.dtype))
# Check specific shape axes.
if spec.axes:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for axis, value in spec.axes.items():
if hasattr(value, 'value'):
@@ -1431,7 +1460,7 @@ class Layer(checkpointable.CheckpointableBase):
' but received input with shape ' + str(shape))
# Check shape.
if spec.shape is not None:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for spec_dim, dim in zip(spec.shape, shape):
if spec_dim is not None and dim is not None:
@@ -1706,12 +1735,12 @@ class DeferredTensor(object):
def __str__(self):
return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
def __repr__(self):
return "<DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
@@ -1806,11 +1835,13 @@ def make_variable(name,
dtype=dtypes.float32,
initializer=None,
partition_info=None,
- trainable=True,
+ trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1836,11 +1867,21 @@ def make_variable(name,
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
caching_device: Passed to `vs.variable`.
validate_shape: Passed to `vs.variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: Not handled at this time.
Returns:
@@ -1872,5 +1913,7 @@ def make_variable(name,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
return v
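The trainable-resolution rule that add_weight and make_variable now share can be summarized in a few lines (illustrative sketch, not the library code):

    from tensorflow.python.ops import variable_scope as vs

    def resolve_trainable(trainable, synchronization):
      if synchronization == vs.VariableSynchronization.ON_READ:
        if trainable:
          raise ValueError('ON_READ variables must not be trainable.')
        return False               # synced-on-read variables are never trainable
      return True if trainable is None else trainable

    assert resolve_trainable(None, vs.VariableSynchronization.AUTO) is True
    assert resolve_trainable(None, vs.VariableSynchronization.ON_READ) is False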
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 8e632651fa..bd03f4871f 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -599,7 +599,7 @@ class Model(Network):
# Unconditional updates
updates += self.get_updates_for(None)
# Conditional updates relevant to this model
- updates += self.get_updates_for(self._feed_inputs)
+ updates += self.get_updates_for(self.inputs)
# Stateful metrics updates
updates += self.metrics_updates
# Gets loss and metrics. Updates weights at each call.
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e82f5c0332..adefffab11 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -124,12 +124,10 @@ def fit_loop(model,
callback_metrics = copy.copy(out_labels) + [
'val_' + n for n in out_labels
]
- if callbacks is not None and any(
- [isinstance(callback, cbks.TensorBoard) for callback in callbacks]):
- # need to create the test_function before start of the first epoch
- # because TensorBoard callback on_epoch_begin adds summary to the
- # list of fetches of the test_function
- model._make_test_function()
+ # Need to create the test_function before the start of the first epoch
+ # because the TensorBoard callback's on_epoch_begin adds the summary op
+ # to the test_function's list of fetches.
+ model._make_test_function()
else:
callback_metrics = copy.copy(out_labels)
@@ -162,7 +160,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -170,11 +168,17 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
cbk.validation_data = val_ins
+ # validation_data must be set before on_train_begin() is called
+ # so that the TensorBoard callback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index e8838cd3bc..c78684c9f4 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -989,7 +989,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -997,9 +997,11 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
if not val_inputs:
cbk.validation_data = []
@@ -1009,6 +1011,10 @@ def fit_loop(model,
cbk.validation_data = val_inputs + val_targets + val_sample_weights
else:
cbk.validation_data = val_inputs + val_targets
+ # validation_data must be set before on_train_begin() is called
+ # so that the TensorBoard callback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index d81b384f0e..432cf2bddd 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -96,14 +96,25 @@ def fit_generator(model,
else:
callback_model = model
callbacks.set_model(callback_model)
- callbacks.set_params({
+
+ callback_params = {
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
- })
- callbacks.on_train_begin()
+ }
+ if do_validation:
+ # Need to create the test_function before the start of the first epoch
+ # because the TensorBoard callback's on_epoch_begin adds the summary op
+ # to the test_function's list of fetches.
+ model._make_test_function()
+ # Determine the number of validation batches for a validation generator.
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ elif isinstance(validation_data, Sequence):
+ callback_params.update({'validation_steps': len(validation_data)})
+ callbacks.set_params(callback_params)
enqueuer = None
val_enqueuer = None
@@ -149,6 +160,9 @@ def fit_generator(model,
output_generator = generator
callback_model.stop_training = False
+ # validation_data must be set before on_train_begin() is called,
+ # so that the TensorBoard callback can validate its input
+ callbacks.on_train_begin()
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
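For generators the step count can be inferred: when `validation_steps` is absent but `validation_data` is a `Sequence`, `len(validation_data)` supplies it. A standalone sketch of that fallback, with `FiveBatches` as a hypothetical stand-in:

from tensorflow.python.keras.utils.data_utils import Sequence

class FiveBatches(Sequence):
  """Hypothetical validation Sequence with a known length."""

  def __len__(self):
    return 5

  def __getitem__(self, idx):
    raise NotImplementedError  # batch loading elided in this sketch

validation_data = FiveBatches()
validation_steps = None
if validation_steps:
  steps = validation_steps
elif isinstance(validation_data, Sequence):
  steps = len(validation_data)  # mirrors the elif branch above
assert steps == 5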
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index b9b2e9ad59..28beb6760d 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -23,6 +23,9 @@ import six
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops.init_ops import Constant
+from tensorflow.python.ops.init_ops import glorot_normal_initializer
+from tensorflow.python.ops.init_ops import glorot_uniform_initializer
+
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
@@ -80,52 +83,6 @@ def lecun_uniform(seed=None):
scale=1., mode='fan_in', distribution='uniform', seed=seed)
-@tf_export('keras.initializers.glorot_normal')
-def glorot_normal(seed=None):
- """Glorot normal initializer, also called Xavier normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(2 / (fan_in + fan_out))`
- where `fan_in` is the number of input units in the weight tensor
- and `fan_out` is the number of output units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- Glorot & Bengio, AISTATS 2010
- http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_avg', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.glorot_uniform')
-def glorot_uniform(seed=None):
- """Glorot uniform initializer, also called Xavier uniform initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(6 / (fan_in + fan_out))`
- where `fan_in` is the number of input units in the weight tensor
- and `fan_out` is the number of output units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- Glorot & Bengio, AISTATS 2010
- http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_avg', distribution='uniform', seed=seed)
-
-
@tf_export('keras.initializers.he_normal')
def he_normal(seed=None):
"""He normal initializer.
@@ -179,6 +136,8 @@ normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
+glorot_normal = glorot_normal_initializer
+glorot_uniform = glorot_uniform_initializer
# pylint: enable=invalid-name
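The two Glorot symbols are now plain aliases for the core initializers (dual-exported from init_ops.py later in this patch), so both spellings yield the same `VarianceScaling` configuration. A quick check, assuming a build that includes this change:

import tensorflow as tf

a = tf.keras.initializers.glorot_uniform()
b = tf.glorot_uniform_initializer()
# Both construct VarianceScaling(scale=1., mode='fan_avg',
# distribution='uniform'), so the types match.
assert type(a) is type(b)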
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 2bf6229ccb..f28cade474 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -26,6 +26,7 @@ import warnings
import numpy as np
from tensorflow.python.eager import context
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
@@ -929,13 +930,13 @@ class Dense(Layer):
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
- shape = inputs.get_shape().as_list()
- if len(shape) > 2:
+ rank = common_shapes.rank(inputs)
+ if rank > 2:
# Broadcasting is required for the inputs.
- outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
- [0]])
+ outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
# Reshape the output back to the original ndim of the input.
if not context.executing_eagerly():
+ shape = inputs.get_shape().as_list()
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else:
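The rank is now looked up once via `common_shapes.rank`, and for rank > 2 the kernel is applied with `tensordot` over the last axis; the static reshape only runs in graph mode. The user-visible behavior, as a small sketch:

import tensorflow as tf

layer = tf.keras.layers.Dense(5)
x = tf.zeros([2, 3, 4])   # rank 3, so the tensordot branch is taken
y = layer(x)              # contracts the last axis against the kernel
assert y.shape.as_list() == [2, 3, 5]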
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 910fff720f..629a9ec9a1 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -112,6 +112,7 @@ class Embedding(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.embeddings_constraint = constraints.get(embeddings_constraint)
self.mask_zero = mask_zero
+ self.supports_masking = mask_zero
self.input_length = input_length
@tf_utils.shape_type_conversion
@@ -127,8 +128,8 @@ class Embedding(Layer):
def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
return None
- else:
- return math_ops.not_equal(inputs, 0)
+
+ return math_ops.not_equal(inputs, 0)
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
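With `supports_masking` tied to `mask_zero`, downstream layers only see a mask when `mask_zero=True`, and that mask is just `not_equal(inputs, 0)`. For example:

import tensorflow as tf

layer = tf.keras.layers.Embedding(10, 4, mask_zero=True)
ids = tf.constant([[3, 7, 0, 0]])
mask = layer.compute_mask(ids)  # same as tf.not_equal(ids, 0)
# mask evaluates to [[True, True, False, False]]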
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 8b894ca6b1..a7835bc0a2 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -181,12 +181,6 @@ class BatchNormalization(Layer):
self.renorm_clipping = renorm_clipping
self.renorm_momentum = renorm_momentum
- def _add_tower_local_variable(self, *args, **kwargs):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope(
- variable_scope.VariableAggregation.MEAN):
- return self.add_weight(*args, **kwargs)
-
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -314,19 +308,23 @@ class BatchNormalization(Layer):
self._scope.set_partitioner(None)
else:
partitioner = None
- self.moving_mean = self._add_tower_local_variable(
+ self.moving_mean = self.add_weight(
name='moving_mean',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- self.moving_variance = self._add_tower_local_variable(
+ self.moving_variance = self.add_weight(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
@@ -337,12 +335,14 @@ class BatchNormalization(Layer):
# stack to be cleared. The nested ones use a `lambda` to set the desired
# device and ignore any devices that may be set by the custom getter.
def _renorm_variable(name, shape):
- var = self._add_tower_local_variable(
+ var = self.add_weight(
name=name,
shape=shape,
dtype=param_dtype,
initializer=init_ops.zeros_initializer(),
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
return var
with distribute_lib.get_distribution_strategy().colocate_vars_with(
@@ -370,7 +370,7 @@ class BatchNormalization(Layer):
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
+ update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
@@ -619,6 +619,10 @@ class BatchNormalization(Layer):
else:
mean, variance = self.moving_mean, self.moving_variance
+ mean = math_ops.cast(mean, inputs.dtype)
+ variance = math_ops.cast(variance, inputs.dtype)
+ if offset is not None:
+ offset = math_ops.cast(offset, inputs.dtype)
outputs = nn.batch_normalization(inputs,
_broadcast(mean),
_broadcast(variance),
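The deleted `_add_tower_local_variable` helper is replaced by `add_weight` with explicit `synchronization` and `aggregation` arguments; `ON_READ` also forces the weight to be non-trainable. A minimal sketch of the same pattern in a hypothetical custom layer:

import tensorflow as tf
from tensorflow.python.ops import variable_scope

class RunningMean(tf.keras.layers.Layer):
  """Hypothetical layer keeping a tower-local running mean."""

  def build(self, input_shape):
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=(int(input_shape[-1]),),
        initializer=tf.zeros_initializer(),
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=variable_scope.VariableAggregation.MEAN)
    super(RunningMean, self).build(input_shape)

  def call(self, inputs):
    return inputs - self.moving_mean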
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index b22f3bd152..a97b4cac46 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -95,6 +95,24 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+ def test_batchnorm_mixed_precision(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
+ model.add(norm)
+ model.compile(loss='mse', optimizer='sgd')
+
+ # centered on 5.0, variance 10.0
+ x = np.random.normal(
+ loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16)
+ model.fit(x, x, epochs=4, verbose=0)
+ out = model.predict(x)
+ out -= keras.backend.eval(norm.beta)
+ out /= keras.backend.eval(norm.gamma)
+
+ np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
+ np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+
def test_batchnorm_convnet(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 32d25c5a65..61775da47b 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -235,7 +235,8 @@ class RNN(Layer):
"""Base class for recurrent layers.
Arguments:
- cell: A RNN cell instance. A RNN cell is a class that has:
+ cell: An RNN cell instance or a list of RNN cell instances.
+ An RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
@@ -248,9 +249,9 @@ class RNN(Layer):
(one size per state). In this case, the first entry
(`state_size[0]`) should be the same as
the size of the cell output.
- It is also possible for `cell` to be a list of RNN cell instances,
- in which cases the cells get stacked on after the other in the RNN,
- implementing an efficient stacked RNN.
+ In the case that `cell` is a list of RNN cell instances, the cells
+ will be stacked one after the other in the RNN, implementing an
+ efficient stacked RNN.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
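Passing a list of cells stacks them in order, so the layer's output comes from the last cell. For example:

import tensorflow as tf

cells = [tf.keras.layers.SimpleRNNCell(8), tf.keras.layers.SimpleRNNCell(8)]
layer = tf.keras.layers.RNN(cells)      # cells stacked one after the other
outputs = layer(tf.zeros([2, 10, 4]))   # last cell's output: shape (2, 8)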
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index e61acf8e77..f651e03874 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -169,6 +169,38 @@ class TimeDistributed(Wrapper):
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
+ def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
+ """Finds non-specific dimensions in the static shapes.
+
+ The static shapes are replaced with the corresponding dynamic shapes of the
+ tensor.
+
+ Arguments:
+ init_tuple: a tuple, the first part of the output shape
+ tensor: the tensor from which to get the (static and dynamic) shapes
+ as the last part of the output shape
+ start_idx: int, which indicates the first dimension to take from
+ the static shape of the tensor
+ int_shape: an alternative static shape to take as the last part
+ of the output shape
+ Returns:
+ The new int_shape with the first part from init_tuple
+ and the last part from either `int_shape` (if provided)
+ or `tensor.shape`, where every `None` is replaced by
+ the corresponding dimension from `tf.shape(tensor)`.
+ """
+ # Replace every None in int_shape with the matching entry of K.shape.
+ if int_shape is None:
+ int_shape = K.int_shape(tensor)[start_idx:]
+ if not any(not s for s in int_shape):
+ return init_tuple + tuple(int_shape)
+ shape = K.shape(tensor)
+ int_shape = list(int_shape)
+ for i, s in enumerate(int_shape):
+ if not s:
+ int_shape[i] = shape[start_idx + i]
+ return init_tuple + tuple(int_shape)
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
assert len(input_shape) >= 3
@@ -224,18 +256,24 @@ class TimeDistributed(Wrapper):
input_length = input_shape[1]
if not input_length:
input_length = array_ops.shape(inputs)[1]
+ inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = generic_utils.object_list_uid(inputs)
- inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
+ inputs = array_ops.reshape(inputs, inner_input_shape)
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
+ if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ kwargs['mask'] = K.reshape(mask, inner_mask_shape)
y = self.layer.call(inputs, **kwargs)
if hasattr(y, '_uses_learning_phase'):
uses_learning_phase = y._uses_learning_phase
# Shape: (num_samples, timesteps, ...)
output_shape = self.compute_output_shape(input_shape).as_list()
- y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+ output_shape = self._get_shape_tuple(
+ (-1, input_length), y, 1, output_shape[2:])
+ y = array_ops.reshape(y, output_shape)
# Apply activity regularizer if any:
if (hasattr(self.layer, 'activity_regularizer') and
@@ -247,6 +285,80 @@ class TimeDistributed(Wrapper):
y._uses_learning_phase = True
return y
+ def compute_mask(self, inputs, mask=None):
+ """Computes an output mask tensor for Embedding layer.
+
+ This is based on the inputs, mask, and the inner layer.
+ If batch size is specified:
+ Simply return the input `mask`. (An rnn-based implementation with
+ more than one rnn inputs is required but not supported in tf.keras yet.)
+ Otherwise we call `compute_mask` of the inner layer at each time step.
+ If the output mask at each time step is not `None`:
+ (E.g., inner layer is Masking or RNN)
+ Concatenate all of them and return the concatenation.
+ If the output mask at each time step is `None` and the input mask is not
+ `None`:(E.g., inner layer is Dense)
+ Reduce the input_mask to 2 dimensions and return it.
+ Otherwise (both the output mask and the input mask are `None`):
+ (E.g., `mask` is not used at all)
+ Return `None`.
+
+ Arguments:
+ inputs: Tensor with shape [batch size, timesteps, ...] indicating the
+ input to TimeDistributed. If static shape information is available for
+ "batch size", `mask` is returned unmodified.
+ mask: Either None (indicating no masking) or a Tensor indicating the
+ input mask for TimeDistributed. The shape can be static or dynamic.
+
+ Returns:
+ Either None (no masking), or a [batch size, timesteps, ...] Tensor with
+ an output mask for the TimeDistributed layer with the shape beyond the
+ second dimension being the value of the input mask shape(if the computed
+ output mask is none), an output mask with the shape beyond the first
+ dimension being the value of the mask shape(if mask is not None) or
+ output mask with the shape beyond the first dimension being the
+ value of the computed output shape.
+
+ """
+ # Cases in which the inner layer's compute_mask must be called even when
+ # input_mask is None: a Masking layer, or an Embedding layer with mask_zero.
+ input_shape = K.int_shape(inputs)
+ if input_shape[0]:
+ # The batch size is statically known; we currently do not handle the
+ # mask explicitly.
+ return mask
+ inner_mask = mask
+ if inner_mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ inner_mask = K.reshape(inner_mask, inner_mask_shape)
+ input_uid = generic_utils.object_list_uid(inputs)
+ inner_inputs = self._input_map[input_uid]
+ output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
+ if output_mask is None:
+ if mask is None:
+ return None
+ # input_mask is not None, and output_mask is None:
+ # we should return a not-None mask
+ output_mask = mask
+ for _ in range(2, len(K.int_shape(mask))):
+ output_mask = K.any(output_mask, axis=-1)
+ else:
+ # output_mask is not None. We need to reshape it
+ input_length = input_shape[1]
+ if not input_length:
+ input_length = K.shape(inputs)[1]
+ output_mask_int_shape = K.int_shape(output_mask)
+ if output_mask_int_shape is None:
+ # if the output_mask does not have a static shape,
+ # its shape must be the same as mask's
+ if mask is not None:
+ output_mask_int_shape = K.int_shape(mask)
+ else:
+ output_mask_int_shape = K.compute_output_shape(input_shape)[:-1]
+ output_mask_shape = self._get_shape_tuple(
+ (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
+ output_mask = K.reshape(output_mask, output_mask_shape)
+ return output_mask
+
@tf_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
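The `tf.shape` lookup in `_get_shape_tuple` only happens when some static dimension is unknown, which is what lets `TimeDistributed` reshape inputs whose timestep or feature dimensions are `None`. A simplified standalone analogue of the helper (not the method itself):

import tensorflow as tf

def fill_unknown_dims(init_tuple, tensor, start_idx):
  # Keep known static dims; substitute a dynamic dim for each None.
  static = tensor.get_shape().as_list()[start_idx:]
  if all(s is not None for s in static):
    return init_tuple + tuple(static)
  dynamic = tf.shape(tensor)
  return init_tuple + tuple(
      s if s is not None else dynamic[start_idx + i]
      for i, s in enumerate(static))

x = tf.placeholder(tf.float32, shape=(None, None, 4))
flat = tf.reshape(x, fill_unknown_dims((-1,), x, 2))  # (-1, 4), all static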
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index c8f0d216e6..3f268acf5c 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -190,8 +190,8 @@ class TimeDistributedTest(test.TestCase):
x = keras.layers.Input(shape=(3, 2))
layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
_ = layer(x)
- assert len(layer.updates) == 2
- assert len(layer.trainable_weights) == 2
+ self.assertEquals(len(layer.updates), 2)
+ self.assertEquals(len(layer.trainable_weights), 2)
layer.trainable = False
assert not layer.updates
assert not layer.trainable_weights
@@ -199,6 +199,62 @@ class TimeDistributedTest(test.TestCase):
assert len(layer.updates) == 2
assert len(layer.trainable_weights) == 2
+ def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
+ with self.test_session():
+ # test with unspecified shape and Embeddings with mask_zero
+ model = keras.models.Sequential()
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.Embedding(5, 6, mask_zero=True),
+ input_shape=(None, None))) # N by t_1 by t_2 by 6
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.SimpleRNN(7, return_sequences=True)))
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.SimpleRNN(8, return_sequences=False)))
+ model.add(keras.layers.SimpleRNN(1, return_sequences=False))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model_input = np.random.randint(low=1, high=5, size=(10, 3, 4),
+ dtype='int32')
+ for i in range(4):
+ model_input[i, i:, i:] = 0
+ model.fit(model_input,
+ np.random.random((10, 1)), epochs=1, batch_size=10)
+ mask_outputs = [model.layers[0].compute_mask(model.input)]
+ for layer in model.layers[1:]:
+ mask_outputs.append(layer.compute_mask(layer.input, mask_outputs[-1]))
+ func = keras.backend.function([model.input], mask_outputs[:-1])
+ mask_outputs_val = func([model_input])
+ ref_mask_val_0 = model_input > 0 # embedding layer
+ ref_mask_val_1 = ref_mask_val_0 # first RNN layer
+ ref_mask_val_2 = np.any(ref_mask_val_1, axis=-1) # second RNN layer
+ ref_mask_val = [ref_mask_val_0, ref_mask_val_1, ref_mask_val_2]
+ for i in range(3):
+ self.assertAllEqual(mask_outputs_val[i], ref_mask_val[i])
+ self.assertIs(mask_outputs[-1], None) # final layer
+
+ def test_TimeDistributed_with_masking_layer(self):
+ with self.test_session():
+ # test with Masking layer
+ model = keras.models.Sequential()
+ model.add(keras.layers.TimeDistributed(keras.layers.Masking(
+ mask_value=0.,), input_shape=(None, 4)))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(5)))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model_input = np.random.randint(low=1, high=5, size=(10, 3, 4))
+ for i in range(4):
+ model_input[i, i:, :] = 0.
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.fit(model_input,
+ np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+ mask_outputs = [model.layers[0].compute_mask(model.input)]
+ mask_outputs += [model.layers[1].compute_mask(model.layers[1].input,
+ mask_outputs[-1])]
+ func = keras.backend.function([model.input], mask_outputs)
+ mask_outputs_val = func([model_input])
+ self.assertEqual((mask_outputs_val[0]).all(),
+ model_input.all())
+ self.assertEqual((mask_outputs_val[1]).all(),
+ model_input.all())
+
class BidirectionalTest(test.TestCase):
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index ad3819e6e7..1525104ac9 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -37,6 +37,7 @@ class TestModelCloning(test.TestCase):
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(4,)))
+ model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(4))
@@ -46,6 +47,8 @@ class TestModelCloning(test.TestCase):
with self.test_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
+ # update ops from batch norm needs to be included
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(val_a, val_out)
@@ -53,6 +56,7 @@ class TestModelCloning(test.TestCase):
input_a = keras.Input(shape=(4,))
new_model = keras.models.clone_model(
model, input_tensors=input_a)
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(val_a, val_out)
@@ -60,6 +64,7 @@ class TestModelCloning(test.TestCase):
input_a = keras.backend.variable(val_a)
new_model = keras.models.clone_model(
model, input_tensors=input_a)
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(None, val_out)
@@ -76,6 +81,7 @@ class TestModelCloning(test.TestCase):
x_a = dense_1(input_a)
x_a = keras.layers.Dropout(0.5)(x_a)
+ x_a = keras.layers.BatchNormalization()(x_a)
x_b = dense_1(input_b)
x_a = dense_2(x_a)
outputs = keras.layers.add([x_a, x_b])
@@ -87,6 +93,7 @@ class TestModelCloning(test.TestCase):
with self.test_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch([val_a, val_b], val_out)
@@ -95,6 +102,7 @@ class TestModelCloning(test.TestCase):
input_b = keras.Input(shape=(4,), name='b')
new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b])
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch([val_a, val_b], val_out)
@@ -103,6 +111,7 @@ class TestModelCloning(test.TestCase):
input_b = keras.backend.variable(val_b)
new_model = keras.models.clone_model(
model, input_tensors=[input_a, input_b])
+ self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(None, val_out)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 6bfd1936e3..838cf836f1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1525,6 +1525,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
+ tags = ["no_windows_gpu"],
)
cuda_py_test(
@@ -2057,6 +2058,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
+ tags = ["no_windows_gpu"],
)
tf_py_test(
@@ -2843,6 +2845,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
],
shard_count = 20,
+ tags = ["nomsan"], # TODO(b/110990716) reenable
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 5272a3631f..24800d2b7a 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -1097,10 +1097,8 @@ class PartitionedCallTest(test.TestCase):
self.assertEqual(value, 2.0)
def testFunctionWithResourcesOnDifferentDevices(self):
- # TODO(akshayka): Remove the `skipTest` once we can whitelist ops as
- # safe to be invoked with resources on different devices.
- self.skipTest("The Placer disallows ops with resource inputs "
- "on different devices.")
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPUs available.")
with ops.device("/cpu:0"):
v_cpu_zero = resource_variable_ops.ResourceVariable(
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 927ca012ae..f6097ad489 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -830,7 +830,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1], [2], [3], [4], [5], [6]]:
convolution = convolutional.conv1d
inputs = random_ops.random_normal(shape, dtype=dtype)
@@ -925,7 +925,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]:
convolution = convolutional.conv2d
inputs = random_ops.random_normal(shape, dtype=dtype)
@@ -1050,7 +1050,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1, 1, 1], [2, 2, 2], [3, 3, 3]]:
convolution = convolutional.conv3d
inputs = random_ops.random_normal(shape, dtype=dtype)
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 69d3aa4017..487418e694 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -197,7 +197,7 @@ cuda_py_test(
cuda_py_test(
name = "linear_operator_low_rank_update_test",
- size = "medium",
+ size = "large",
srcs = ["linear_operator_low_rank_update_test.py"],
additional_deps = [
"//tensorflow/python/ops/linalg",
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 34b35a4ffb..0e38dbd48d 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -49,12 +49,6 @@ class BaseLinearOperatorLowRankUpdatetest(object):
_use_v = None
@property
- def _dtypes_to_test(self):
- # TODO(langmore) Test complex types once cholesky works with them.
- # See comment in LinearOperatorLowRankUpdate.__init__.
- return [dtypes.float32, dtypes.float64]
-
- @property
def _operator_build_infos(self):
build_info = linear_operator_test_util.OperatorBuildInfo
# Previously we had a (2, 10, 10) shape at the end. We did this to test the
@@ -68,6 +62,15 @@ class BaseLinearOperatorLowRankUpdatetest(object):
build_info((3, 4, 4)),
build_info((2, 1, 4, 4))]
+ def _gen_positive_diag(self, dtype, diag_shape):
+ if dtype.is_complex:
+ diag = linear_operator_test_util.random_uniform(
+ diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
+ return math_ops.cast(diag, dtype=dtype)
+
+ return linear_operator_test_util.random_uniform(
+ diag_shape, minval=1e-4, maxval=1., dtype=dtype)
+
def _operator_and_matrix(self, build_info, dtype, use_placeholder):
# Recall A = L + UDV^H
shape = list(build_info.shape)
@@ -78,8 +81,7 @@ class BaseLinearOperatorLowRankUpdatetest(object):
# base_operator L will be a symmetric positive definite diagonal linear
# operator, with condition number as high as 1e4.
- base_diag = linear_operator_test_util.random_uniform(
- diag_shape, minval=1e-4, maxval=1., dtype=dtype)
+ base_diag = self._gen_positive_diag(dtype, diag_shape)
lin_op_base_diag = base_diag
# U
@@ -94,8 +96,7 @@ class BaseLinearOperatorLowRankUpdatetest(object):
# D
if self._is_diag_update_positive:
- diag_update = linear_operator_test_util.random_uniform(
- diag_update_shape, minval=1e-4, maxval=1., dtype=dtype)
+ diag_update = self._gen_positive_diag(dtype, diag_update_shape)
else:
diag_update = linear_operator_test_util.random_normal(
diag_update_shape, stddev=1e-4, dtype=dtype)
@@ -110,7 +111,9 @@ class BaseLinearOperatorLowRankUpdatetest(object):
diag_update, shape=None)
base_operator = linalg.LinearOperatorDiag(
- lin_op_base_diag, is_positive_definite=True)
+ lin_op_base_diag,
+ is_positive_definite=True,
+ is_self_adjoint=True)
operator = linalg.LinearOperatorLowRankUpdate(
base_operator,
@@ -169,6 +172,7 @@ class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
self._rtol[dtypes.float32] = 1e-5
self._atol[dtypes.float64] = 1e-10
self._rtol[dtypes.float64] = 1e-10
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
@@ -188,6 +192,7 @@ class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
self._rtol[dtypes.float32] = 1e-4
self._atol[dtypes.float64] = 1e-9
self._rtol[dtypes.float64] = 1e-9
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
@@ -206,6 +211,7 @@ class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
self._rtol[dtypes.float32] = 1e-5
self._atol[dtypes.float64] = 1e-10
self._rtol[dtypes.float64] = 1e-10
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
@@ -225,6 +231,7 @@ class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
self._rtol[dtypes.float32] = 1e-4
self._atol[dtypes.float64] = 1e-9
self._rtol[dtypes.float64] = 1e-9
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index 167c6cacd1..b389e0cbdf 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
@@ -32,12 +31,6 @@ class LinearOperatorLowerTriangularTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- @property
- def _dtypes_to_test(self):
- # TODO(langmore) Test complex types once supported by
- # matrix_triangular_solve.
- return [dtypes.float32, dtypes.float64]
-
def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
# Upper triangle will be nonzero, but ignored.
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 0fb0b8895c..e358293a90 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -852,5 +852,62 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_same_as_var_dtype(self):
+ # read_dtype is same as dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float32)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # dtype is not read_dtype, return NotImplemented
+ self.assertEqual(
+ NotImplemented, v._dense_var_to_tensor(dtype=dtypes.float16))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=True))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_different_from_var_dtype(self):
+ # read_dtype is different from dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float16)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 957baf8c60..acee180a6c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -268,6 +268,12 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndRNNCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndRNNCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyGRUCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyGRUCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
######### Benchmarking RNN code
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 054c6f9dd7..ae2a0ab29a 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -1054,7 +1054,7 @@ class VariableScopeTest(test.TestCase):
"testGetCollection_foo/testGetCollection_a:0"
])
- def testGetTrainableVariables(self):
+ def testGetTrainableVariablesWithGetVariable(self):
with self.test_session():
_ = variable_scope.get_variable("testGetTrainableVariables_a", [])
with variable_scope.variable_scope(
@@ -1062,10 +1062,72 @@ class VariableScopeTest(test.TestCase):
_ = variable_scope.get_variable("testGetTrainableVariables_b", [])
_ = variable_scope.get_variable(
"testGetTrainableVariables_c", [], trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_d", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
self.assertEqual(
[v.name for v in scope.trainable_variables()],
- ["testGetTrainableVariables_foo/"
- "testGetTrainableVariables_b:0"])
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values set trainable=True
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
+ def testGetTrainableVariablesWithVariable(self):
+ with self.test_session():
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b")
+ _ = variable_scope.variable(
+ 1.0, name="testGetTrainableVariables_c", trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_d",
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
+ self.assertEqual(
+ [v.name for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values set trainable=True
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
def testGetGlobalVariables(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 62d596da91..2b9c62ad6f 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase):
iterated_partitions = list(partitioned_variable)
self.assertEqual(2, num_partitions)
self.assertEqual([v0, v1], iterated_partitions)
+ self.assertEqual([2], partitioned_variable.get_shape())
+ self.assertEqual([2], partitioned_variable.shape)
self.assertEqual([2], concatenated.get_shape())
self.assertEqual([2], concatenated.shape)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index b8969a41ab..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_weight(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=None,
+ constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -190,7 +207,21 @@ class Layer(base_layer.Layer):
Raises:
RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
+ ValueError: When trainable has been set to True with synchronization
+ set to `ON_READ`.
"""
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to False when the variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
constraint=constraint,
partitioner=partitioner,
use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
getter=vs.get_variable)
if regularizer:
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 298e96e711..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
# regularizers only supported in GRAPH mode.
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_var', [2, 2],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ # Test that sync `ON_READ` variables are defaulted to be non-trainable.
+ variable_3 = layer.add_variable(
+ 'sync_on_read_var', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+ def testInvalidTrainableSynchronizationCombination(self):
+ layer = base_layers.Layer(name='my_layer')
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.'):
+ _ = layer.add_variable(
+ 'v', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
def testReusePartitionedVaraiblesAndRegularizers(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
partitioner=partitioner,
reuse=reuse):
layer = base_layers.Layer(name='my_layer')
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_part_var', [4, 4],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h
index d4621d61ee..0098d938a0 100644
--- a/tensorflow/python/lib/core/numpy.h
+++ b/tensorflow/python/lib/core/numpy.h
@@ -30,9 +30,10 @@ limitations under the License.
#endif
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
-#include <Python.h>
#include <locale>
+#include <Python.h>
+
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index 6b6c82015f..2ee898ea1d 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -16,9 +16,10 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_util.h"
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
-#include <Python.h>
#include <locale>
+#include <Python.h>
+
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 53ae6d843f..4ecc74675a 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -775,7 +775,7 @@ def While(input_, cond, body, name=None, hostmem=None):
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
- body: . A funcion takes a list of tensors and returns another
+ body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified
by T.
name: A name for the operation (optional).
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index a2eae452ae..9440bab9ee 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -55,6 +55,7 @@ ops.NotDifferentiable('SampleDistortedBoundingBoxV2')
ops.NotDifferentiable('ExtractGlimpse')
ops.NotDifferentiable('NonMaxSuppression')
ops.NotDifferentiable('NonMaxSuppressionV2')
+ops.NotDifferentiable('NonMaxSuppressionWithOverlaps')
# pylint: disable=invalid-name
@@ -1752,6 +1753,22 @@ def is_jpeg(contents, name=None):
return math_ops.equal(substr, b'\xff\xd8\xff', name=name)
+def _is_png(contents, name=None):
+ r"""Convenience function to check if the 'contents' encodes a PNG image.
+
+ Args:
+ contents: 0-D `string`. The encoded image bytes.
+ name: A name for the operation (optional)
+
+ Returns:
+ A scalar boolean tensor indicating if 'contents' may be a PNG image.
+ is_png is susceptible to false positives.
+ """
+ with ops.name_scope(name, 'is_png'):
+ substr = string_ops.substr(contents, 0, 3)
+ return math_ops.equal(substr, b'\211PN', name=name)
+
+
@tf_export('image.decode_image')
def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
@@ -1829,8 +1846,8 @@ def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
def check_png():
"""Checks if an image is PNG."""
- is_png = math_ops.equal(substr, b'\211PN', name='is_png')
- return control_flow_ops.cond(is_png, _png, check_gif, name='cond_png')
+ return control_flow_ops.cond(
+ _is_png(contents), _png, check_gif, name='cond_png')
def _jpeg():
"""Decodes a jpeg image."""
@@ -2093,6 +2110,50 @@ def non_max_suppression(boxes,
iou_threshold, score_threshold)
+@tf_export('image.non_max_suppression_overlaps')
+def non_max_suppression_with_overlaps(overlaps,
+ scores,
+ max_output_size,
+ overlap_threshold=0.5,
+ score_threshold=float('-inf'),
+ name=None):
+ """Greedily selects a subset of bounding boxes in descending order of score.
+
+ Prunes away boxes that have high overlap with previously selected boxes.
+ The n-by-n overlap values are supplied as a square matrix.
+ The output of this operation is a set of integers indexing into the input
+ collection of bounding boxes representing the selected boxes. The bounding
+ box coordinates corresponding to the selected indices can then be obtained
+ using the `tf.gather` operation. For example:
+ selected_indices = tf.image.non_max_suppression_overlaps(
+ overlaps, scores, max_output_size, overlap_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+
+ Args:
+ overlaps: A 2-D float `Tensor` of shape `[num_boxes, num_boxes]`.
+ scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
+ score corresponding to each box (each row of boxes).
+ max_output_size: A scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ overlap_threshold: A float representing the threshold for deciding whether
+ boxes overlap too much with respect to the provided overlap values.
+ score_threshold: A float representing the threshold for deciding when to
+ remove boxes based on score.
+ name: A name for the operation (optional).
+
+ Returns:
+ selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the
+ selected indices from the overlaps tensor, where `M <= max_output_size`.
+ """
+ with ops.name_scope(name, 'non_max_suppression_overlaps'):
+ overlap_threshold = ops.convert_to_tensor(
+ overlap_threshold, name='overlap_threshold')
+ # pylint: disable=protected-access
+ return gen_image_ops._non_max_suppression_v3(
+ overlaps, scores, max_output_size, overlap_threshold, score_threshold)
+ # pylint: enable=protected-access
+
+
_rgb_to_yiq_kernel = [[0.299, 0.59590059,
0.2115], [0.587, -0.27455667, -0.52273617],
[0.114, -0.32134392, 0.31119955]]
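Since the pairwise overlaps are passed in directly, any affinity measure works, not just IOU derived from box coordinates. A small worked call with hand-picked (illustrative) values:

import tensorflow as tf

overlaps = tf.constant([[1.0, 0.8, 0.1],
                        [0.8, 1.0, 0.2],
                        [0.1, 0.2, 1.0]])
scores = tf.constant([0.9, 0.8, 0.3])
selected = tf.image.non_max_suppression_overlaps(
    overlaps, scores, max_output_size=3, overlap_threshold=0.5)
# Box 1 overlaps box 0 by 0.8 > 0.5 and is pruned; expected indices: [0, 2].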
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 5bfc5ce2a7..3132f7467f 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -1136,7 +1136,7 @@ convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name
-@tf_export("glorot_uniform_initializer")
+@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform")
def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
@@ -1160,7 +1160,7 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
-@tf_export("glorot_normal_initializer")
+@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal")
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot normal initializer, also called Xavier normal initializer.
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index 5beaea65a5..ed53decc00 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -231,8 +231,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
def _log_abs_determinant(self):
- return math_ops.reduce_sum(
+ log_det = math_ops.reduce_sum(
math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+ if self.dtype.is_complex:
+ log_det = math_ops.cast(log_det, dtype=self.dtype)
+ return log_det
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
diag_term = math_ops.conj(self._diag) if adjoint else self._diag
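The added cast keeps `log_abs_determinant` in the operator's dtype even though the log of an absolute value is real. For instance:

import tensorflow as tf

diag = tf.constant([1 + 1j, 2 - 2j], dtype=tf.complex64)
operator = tf.linalg.LinearOperatorDiag(diag)
log_det = operator.log_abs_determinant()
# log_det has dtype complex64 (with zero imaginary part), matching diag.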
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 08e5896e10..2b2bf80f27 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -18,16 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -153,8 +152,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
`is_X` matrix property hints, which will trigger the appropriate code path.
Args:
- base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or
- `float64` `LinearOperator`. This is `L` above.
+ base_operator: Shape `[B1,...,Bb, M, N]` `LinearOperator`. This is `L` above.
u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
This is `U` above.
diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
@@ -183,23 +181,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
Raises:
ValueError: If `is_X` flags are set in an inconsistent way.
"""
- # TODO(langmore) support complex types.
- # Complex types are not allowed due to tf.cholesky() requiring float.
- # If complex dtypes are allowed, we update the following
- # 1. is_diag_update_positive should still imply that `diag > 0`, but we need
- # to remind the user that this implies diag is real. This is needed
- # because if diag has non-zero imaginary part, it will not be
- # self-adjoint positive definite.
dtype = base_operator.dtype
- allowed_dtypes = [
- dtypes.float16,
- dtypes.float32,
- dtypes.float64,
- ]
- if dtype not in allowed_dtypes:
- raise TypeError(
- "Argument matrix must have dtype in %s. Found: %s"
- % (allowed_dtypes, dtype))
+
+ if diag_update is not None:
+ if is_diag_update_positive and dtype.is_complex:
+ logging.warn("Note: setting is_diag_update_positive with a complex "
+ "dtype means that diagonal is real and positive.")
if diag_update is None:
if is_diag_update_positive is False:
@@ -271,8 +258,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
self._set_diag_operators(diag_update, is_diag_update_positive)
self._is_diag_update_positive = is_diag_update_positive
- check_ops.assert_same_float_dtype((base_operator, self.u, self.v,
- self._diag_update))
self._check_shapes()
# Pre-compute the so-called "capacitance" matrix
@@ -407,6 +392,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
else:
det_c = linalg_ops.matrix_determinant(self._capacitance)
log_abs_det_c = math_ops.log(math_ops.abs(det_c))
+ if self.dtype.is_complex:
+ log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)
return log_abs_det_c + log_abs_det_d + log_abs_det_l
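With the dtype whitelist and the same-float check removed, complex operators are accepted; `is_diag_update_positive` with a complex dtype merely logs the warning above. A minimal construction sketch, assuming the Cholesky path supports the complex dtype as this patch implies:

import tensorflow as tf

base = tf.linalg.LinearOperatorDiag(
    tf.constant([2, 3], dtype=tf.complex64),
    is_positive_definite=True,
    is_self_adjoint=True)
u = tf.constant([[1 + 1j], [1 - 1j]], dtype=tf.complex64)
operator = tf.linalg.LinearOperatorLowRankUpdate(base, u)  # A = L + U U^H
log_det = operator.log_abs_determinant()  # complex64, per the cast above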
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index fb1eb2fedb..ca6d3f5405 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -119,8 +119,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
Args:
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
The lower triangular part of `tril` defines this operator. The strictly
- upper triangle is ignored. Allowed dtypes: `float16`, `float32`,
- `float64`.
+ upper triangle is ignored.
is_non_singular: Expect that this operator is non-singular.
This operator is non-singular if and only if its diagonal elements are
all non-zero.
@@ -137,7 +136,6 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
name: A name for this `LinearOperator`.
Raises:
- TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `is_square` is `False`.
"""
@@ -163,12 +161,12 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
def _check_tril(self, tril):
"""Static check of the `tril` argument."""
- # TODO(langmore) Add complex types once matrix_triangular_solve works for
- # them.
allowed_dtypes = [
dtypes.float16,
dtypes.float32,
dtypes.float64,
+ dtypes.complex64,
+ dtypes.complex128,
]
dtype = tril.dtype
if dtype not in allowed_dtypes:
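Complex dtypes join the whitelist now that `matrix_triangular_solve` supports them. For example:

import tensorflow as tf

tril = tf.constant([[1, 0], [2 + 1j, 3]], dtype=tf.complex64)
operator = tf.linalg.LinearOperatorLowerTriangular(tril)
x = operator.solve(tf.ones([2, 1], dtype=tf.complex64))  # triangular solve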
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index a0dfa543f9..f4a93560be 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -401,7 +401,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None):
import tensorflow as tf
import numpy as np
s, u, v = tf.linalg.svd(a)
- tf_a_approx = tf.matmul(u, tf.matmul(tf.linalg.diag(s), v, adjoint_v=True))
+ tf_a_approx = tf.matmul(u, tf.matmul(tf.linalg.diag(s), v, adjoint_b=True))
u, s, v_adj = np.linalg.svd(a, full_matrices=False)
np_a_approx = np.dot(u, np.dot(np.diag(s), v_adj))
# tf_a_approx and np_a_approx should be numerically close.
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index 8276047cb6..df41933f8a 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -35,9 +35,12 @@ from tensorflow.python.util.tf_export import tf_export
# Assert and Print are special symbols in python, so we must
-# have an upper-case version of them. For users with Python 3 or Python 2.7
-# with `from __future__ import print_function`, we also allow lowercase.
-@tf_export("Print", "print")
+# have an upper-case version of them.
+#
+# For users with Python 3 or Python 2.7
+# with `from __future__ import print_function`, we could also allow lowercase.
+# See https://github.com/tensorflow/tensorflow/issues/18053
+@tf_export("Print")
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index cdb6dc8f22..c28dca5137 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -37,11 +37,11 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
-from tensorflow.python.platform import tf_logging as logging
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -651,6 +651,9 @@ def cast(x, dtype, name=None):
TypeError: If `x` cannot be cast to the `dtype`.
"""
base_type = dtypes.as_dtype(dtype).base_dtype
+ if isinstance(x,
+ (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
+ return x
with ops.name_scope(name, "Cast", [x]) as name:
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
@@ -1222,8 +1225,9 @@ def _ReductionDims(x, axis, reduction_indices):
return axis
else:
# Fast path: avoid creating Rank and Range ops if ndims is known.
- if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access
- return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access
+ rank = common_shapes.rank(x)
+ if rank is not None:
+ return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
if (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.get_shape().is_fully_defined()):
rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D.
@@ -1234,8 +1238,8 @@ def _ReductionDims(x, axis, reduction_indices):
def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
- """Set a reduction's output's shape to be a scalar if we are certain."""
- if (not output.shape.is_fully_defined()) and (not keepdims) and (
+ """Set a reduction's output shape to be a scalar if we are certain."""
+ if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
axis is None) and (reduction_indices is None):
output.set_shape(())
return output
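The cast() fast path added above returns its input unchanged when the requested dtype already matches, instead of inserting a no-op Cast. A minimal sketch of the observable behavior in graph mode, using only public API names:

    import tensorflow as tf

    x = tf.constant([1.0, 2.0], dtype=tf.float32)
    y = tf.cast(x, tf.float32)  # dtype already matches: x is returned as-is
    z = tf.cast(x, tf.int32)    # dtype differs: a Cast op is created
    assert y is x
    assert z is not x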
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index bfd225b0d8..3aedeb6acd 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -73,16 +73,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
A (non-trainable) variable initialized to zero, or if inside a
`DistributionStrategy` scope a tower-local variable container.
"""
- with distribute_lib.get_tower_context().tower_local_var_scope(
- variable_scope.VariableAggregation.SUM):
- # Note that "tower local" implies trainable=False.
- return variable_scope.variable(
- lambda: array_ops.zeros(shape, dtype),
- collections=[
- ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
- ],
- validate_shape=validate_shape,
- name=name)
+ # Note that synchronization "ON_READ" implies trainable=False.
+ return variable_scope.variable(
+ lambda: array_ops.zeros(shape, dtype),
+ collections=[
+ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+ ],
+ validate_shape=validate_shape,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM,
+ name=name)
def _remove_squeezable_dimensions(predictions, labels, weights):
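The metric_variable() rewrite above replaces the tower-local variable scope with explicit synchronization and aggregation arguments. A minimal sketch of equivalent usage, assuming tf.get_variable already accepts these arguments at this revision (the variable name is illustrative):

    import tensorflow as tf

    total = tf.get_variable(
        "total", shape=[], dtype=tf.float32,
        initializer=tf.zeros_initializer(),
        collections=[tf.GraphKeys.LOCAL_VARIABLES,
                     tf.GraphKeys.METRIC_VARIABLES],
        # ON_READ implies trainable=False: each tower updates its own copy,
        # and the copies are SUM-aggregated when the variable is read.
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)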
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index ec4ef0f1ab..77ec3bc0d4 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -592,7 +592,7 @@ class WhileOp(object):
inputs = args[:num_enters]
output_tas = args[num_enters:]
# TODO(agarwal): see which outputs have consumers and only populate the
- # TensorArrays corresonding to those. Or do those paths get trimmed out
+ # TensorArrays corresponding to those. Or do those paths get trimmed out
# from inside the while_loop body?
assert len(inputs) >= len(output_tas)
assert len(inputs) == len(inputs_stacked)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 15cafbbde5..70a89e5ebb 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -867,6 +867,19 @@ class ResourceVariable(variables.Variable):
__array_priority__ = 100
+ def is_initialized(self, name=None):
+ """Checks whether a resource variable has been initialized.
+
+    Outputs a boolean scalar indicating whether the variable has been
+    initialized.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `bool`.
+ """
+ return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
+
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
"""Subtracts a value from this variable.
@@ -1091,6 +1104,113 @@ class _UnreadVariable(ResourceVariable):
ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)
+
+class _MixedPrecisionVariable(ResourceVariable):
+ """Represents a variable that can return in desired dtype when read.
+
+  In mixed precision training, it is usually desirable to use different dtypes
+  for variables and computation. This class wraps a created ResourceVariable
+  when mixed precision training is enabled, allowing layers to perform
+  computation in a dtype different from their variable dtype in order to
+  achieve higher performance without causing quality loss.
+ """
+
+ def __init__(self, var, read_dtype):
+ """Creates a MixedPrecisionVariable.
+
+ Args:
+ var: A ResourceVariable instance.
+      read_dtype: A tf.DType, the dtype returned when the variable is read.
+        Casting is performed if read_dtype differs from var.dtype.
+
+    Returns:
+      A MixedPrecisionVariable instance.
+
+    Raises:
+      ValueError: if var is not a ResourceVariable instance, or read_dtype is
+        not a tf.DType instance.
+ """
+ # pylint: disable=super-init-not-called
+ # We do not call super init on purpose.
+    if not isinstance(var, ResourceVariable):
+      raise ValueError("InvalidArgument: var must be a ResourceVariable "
+                       "instance.")
+    if not isinstance(read_dtype, dtypes.DType):
+      raise ValueError("InvalidArgument: read_dtype must be a tf.DType "
+                       "instance.")
+
+ self._var = var
+ self._trainable = var.trainable
+ self._save_slice_info = None
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ self._in_graph_mode = var._in_graph_mode # pylint: disable=protected-access
+ self._handle = var.handle
+ self._shape = var.shape
+ self._initial_value = None
+ if isinstance(self.handle, ops.EagerTensor):
+ self._handle_name = ""
+ else:
+ self._handle_name = self.handle.name
+ self._unique_id = var._unique_id # pylint: disable=protected-access
+ self._dtype = var.dtype
+ self._constraint = None
+ self._cached_value = None
+ self._is_initialized_op = var._is_initialized_op # pylint: disable=protected-access
+ self._initializer_op = var._initializer_op # pylint: disable=protected-access
+ # This needs to be set before read_value() is called.
+ self._read_dtype = read_dtype
+ if context.executing_eagerly():
+ self._graph_element = None
+ else:
+ self._graph_element = self.read_value()
+ self._handle_deleter = (
+ var._handle_deleter if not self._in_graph_mode # pylint: disable=protected-access
+ else None)
+ # pylint: enable=super-init-not-called
+
+ @property
+ def name(self):
+ return self._var.name
+
+ def value(self):
+ return self._read_variable_op()
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def _read_variable_op(self):
+ with ops.colocate_with(self._handle):
+ res = gen_resource_variable_ops.read_variable_op(self._handle,
+ self._dtype)
+ if self._read_dtype != self._dtype:
+ return math_ops.cast(res, self._read_dtype)
+ else:
+ return res
+
+ def set_shape(self, shape):
+ self._shape = shape
+ self._cached_shape_as_list = None
+
+ @property
+ def op(self):
+ """The op for this variable."""
+ return self._var.op
+
+ @property
+ def read_dtype(self):
+ """The dtype of the returned tensor when reading the var."""
+ return self._read_dtype
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ del name
+ dtype = dtype or self.read_dtype
+ if dtype != self.read_dtype or as_ref:
+ return NotImplemented
+ else:
+ res = self.value()
+ return res
+
+ def _should_act_as_resource_variable(self):
+ """To pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
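_MixedPrecisionVariable is internal, so no public constructor is exported here; a minimal sketch of the intended read-casting behavior, assuming direct access to the private class:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import resource_variable_ops

    var = resource_variable_ops.ResourceVariable(1.0, name="w")  # stored as float32
    mp_var = resource_variable_ops._MixedPrecisionVariable(  # pylint: disable=protected-access
        var, read_dtype=dtypes.float16)
    half = mp_var.read_value()  # the read is cast to float16; storage stays float32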
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 82a044a0d4..70805fd572 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -47,7 +47,6 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -55,16 +54,6 @@ from tensorflow.python.util.tf_export import tf_export
_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"
-
-# TODO(jblespiau): Remove this function when we are sure there are no longer
-# any usage (even if protected, it is being used). Prefer assert_like_rnncell.
-def _like_rnncell(cell):
- """Checks that a given object is an RNNCell by using duck typing."""
- conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
- hasattr(cell, "zero_state"), callable(cell)]
- return all(conditions)
-
-
# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
@@ -1330,48 +1319,3 @@ class MultiRNNCell(RNNCell):
array_ops.concat(new_states, 1))
return cur_inp, new_states
-
-
-class _SlimRNNCell(RNNCell, checkpointable_tracking.NotCheckpointable):
- """A simple wrapper for slim.rnn_cells."""
-
- def __init__(self, cell_fn):
- """Create a SlimRNNCell from a cell_fn.
-
- Args:
- cell_fn: a function which takes (inputs, state, scope) and produces the
- outputs and the new_state. Additionally when called with inputs=None and
- state=None it should return (initial_outputs, initial_state).
-
- Raises:
- TypeError: if cell_fn is not callable
- ValueError: if cell_fn cannot produce a valid initial state.
- """
- if not callable(cell_fn):
- raise TypeError("cell_fn %s needs to be callable", cell_fn)
- self._cell_fn = cell_fn
- self._cell_name = cell_fn.func.__name__
- init_output, init_state = self._cell_fn(None, None)
- output_shape = init_output.get_shape()
- state_shape = init_state.get_shape()
- self._output_size = output_shape.with_rank(2)[1].value
- self._state_size = state_shape.with_rank(2)[1].value
- if self._output_size is None:
- raise ValueError("Initial output created by %s has invalid shape %s" %
- (self._cell_name, output_shape))
- if self._state_size is None:
- raise ValueError("Initial state created by %s has invalid shape %s" %
- (self._cell_name, state_shape))
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- scope = scope or self._cell_name
- output, state = self._cell_fn(inputs, state, scope=scope)
- return output, state
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 1e3f662ff3..af103d3cc7 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -130,7 +130,7 @@ class FuncRegistry(object):
def __init__(self):
self._lock = threading.Lock()
self._unique_id = 0 # GUARDED_BY(self._lock)
- # Only store weakrefs to the funtions. The strong reference is stored in
+ # Only store weakrefs to the functions. The strong reference is stored in
# the graph.
self._funcs = weakref.WeakValueDictionary()
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 8cb6a0537e..2c93cf72c7 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_resource_variable_ops
@@ -124,9 +123,7 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables.
- if context.executing_eagerly() or ref.op.type == "VarHandleOp":
- return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
- name=name)
+ return ref.is_initialized(name=name)
@tf_export("assign_sub")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 1e06bf07d5..77f67c18ee 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -255,7 +255,7 @@ class _VariableStore(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -300,6 +300,8 @@ class _VariableStore(object):
forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys to add the `Variable` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where the
@@ -341,7 +343,8 @@ class _VariableStore(object):
aggregated. Accepted values are constants defined in the class
@{tf.VariableSynchronization}. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
- when to synchronize.
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
@{tf.VariableAggregation}.
@@ -404,7 +407,7 @@ class _VariableStore(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -477,6 +480,10 @@ class _VariableStore(object):
synchronization=synchronization,
aggregation=aggregation)
+ # Set trainable value based on synchronization value.
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
# to the API after users started writing custom getters.
@@ -519,11 +526,20 @@ class _VariableStore(object):
synchronization=synchronization,
aggregation=aggregation)
- def _get_partitioned_variable(
- self, name, partitioner, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- validate_shape=True, use_resource=None, constraint=None):
+ def _get_partitioned_variable(self,
+ name,
+ partitioner,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None):
"""Gets or creates a sharded variable list with these parameters.
The `partitioner` must be a callable that accepts a fully defined
@@ -773,7 +789,7 @@ class _VariableStore(object):
regularizer=None,
partition_info=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
validate_shape=True,
@@ -1136,7 +1152,7 @@ class VariableScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1207,7 +1223,7 @@ class VariableScope(object):
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1422,7 +1438,7 @@ def get_variable(name,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -2334,11 +2350,28 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
return slice_dim, slice_shape
+def _get_trainable_value(synchronization, trainable):
+ """Computes the trainable value based on the given arguments."""
+ if synchronization == VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ.")
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+ return trainable
+
+
def default_variable_creator(next_creator=None, **kwargs):
"""Default variable creator."""
assert next_creator is None
initial_value = kwargs.get("initial_value", None)
- trainable = kwargs.get("trainable", True)
+ trainable = kwargs.get("trainable", None)
collections = kwargs.get("collections", None)
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
@@ -2347,10 +2380,10 @@ def default_variable_creator(next_creator=None, **kwargs):
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
- # Enforce `ON_READ` variables to be not trainable.
+ # Set trainable value based on synchronization value.
synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
- if synchronization == VariableSynchronization.ON_READ:
- trainable = False
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
if use_resource is None:
use_resource = get_variable_scope().use_resource
@@ -2379,7 +2412,7 @@ def _make_getter(captured_getter, captured_previous):
def variable(initial_value=None,
- trainable=True,
+ trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
@@ -2441,6 +2474,8 @@ def variable_creator_scope(variable_creator):
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: If `False`, allows the variable to be initialized with a
@@ -2463,7 +2498,8 @@ def variable_creator_scope(variable_creator):
aggregated. Accepted values are constants defined in the class
@{tf.VariableSynchronization}. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
- when to synchronize.
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
@{tf.VariableAggregation}.
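A minimal sketch of the new trainable defaulting rule introduced above, calling the private helper directly:

    from tensorflow.python.ops import variable_scope as vs

    sync = vs.VariableSynchronization
    # pylint: disable=protected-access
    assert vs._get_trainable_value(sync.AUTO, None) is True
    assert vs._get_trainable_value(sync.ON_READ, None) is False
    try:
      vs._get_trainable_value(sync.ON_READ, True)  # explicit conflict
    except ValueError:
      pass  # ON_READ variables must not be trainable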
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 9a09cdaa52..d3b8da6d2a 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1404,6 +1404,10 @@ class PartitionedVariable(object):
def dtype(self):
return self._dtype
+ @property
+ def shape(self):
+ return self.get_shape()
+
def get_shape(self):
return self._shape
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index eba2baaf6f..fa17b17d10 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -66,11 +66,11 @@ def _global_report_benchmark(
if not isinstance(extras, dict):
raise TypeError("extras must be a dict")
- logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
- "throughput: %g %s", name, iters if iters is not None else -1,
- wall_time if wall_time is not None else -1, cpu_time if
- cpu_time is not None else -1, throughput if
- throughput is not None else -1, str(extras) if extras else "")
+ logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
+ "throughput: %g %s", name, iters if iters is not None else -1,
+ wall_time if wall_time is not None else -1, cpu_time if
+ cpu_time is not None else -1, throughput if
+ throughput is not None else -1, str(extras) if extras else "")
entries = test_log_pb2.BenchmarkEntries()
entry = entries.entry.add()
diff --git a/tensorflow/python/platform/self_check.py b/tensorflow/python/platform/self_check.py
index 966a094e55..844ae99918 100644
--- a/tensorflow/python/platform/self_check.py
+++ b/tensorflow/python/platform/self_check.py
@@ -78,7 +78,7 @@ def preload_check():
"Could not find %r. TensorFlow requires that this DLL be "
"installed in a directory that is named in your %%PATH%% "
"environment variable. Download and install CUDA %s from "
- "this URL: https://developer.nvidia.com/cuda-toolkit"
+ "this URL: https://developer.nvidia.com/cuda-90-download-archive"
% (build_info.cudart_dll_name, build_info.cuda_version_number))
if hasattr(build_info, "cudnn_dll_name") and hasattr(
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
index 8c760e6f52..223d1281ba 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -3,8 +3,9 @@
licenses(["notice"]) # Apache 2.0
-load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
-load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
+load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
exports_files(
[
@@ -13,6 +14,18 @@ exports_files(
],
)
+py_binary(
+ name = "create_python_api",
+ srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
+ main = "//tensorflow/python/tools/api/generator:create_python_api.py",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:no_contrib",
+ "//tensorflow/python/tools/api/generator:doc_srcs",
+ ],
+)
+
py_library(
name = "doc_srcs",
srcs = ["doc_srcs.py"],
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index d746b5d3e4..2a32e8a893 100644
--- a/tensorflow/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -102,36 +102,41 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
-# Creates a genrule that generates a directory structure with __init__.py
-# files that import all exported modules (i.e. modules with tf_export
-# decorators).
-#
-# Args:
-# name: name of genrule to create.
-# output_files: List of __init__.py files that should be generated.
-# This list should include file name for every module exported using
-# tf_export. For e.g. if an op is decorated with
-# @tf_export('module1.module2', 'module3'). Then, output_files should
-# include module1/module2/__init__.py and module3/__init__.py.
-# root_init_template: Python init file that should be used as template for
-# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
-# template will be replaced with root imports collected by this genrule.
-# srcs: genrule sources. If passing root_init_template, the template file
-# must be included in sources.
-# api_name: Name of the project that you want to generate API files for
-# (e.g. "tensorflow" or "estimator").
-# package: Python package containing the @tf_export decorators you want to
-# process
-# package_dep: Python library target containing your package.
-
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
root_init_template = None,
srcs = [],
api_name = "tensorflow",
+ api_version = 2,
package = "tensorflow.python",
- package_dep = "//tensorflow/python:no_contrib"):
+ package_dep = "//tensorflow/python:no_contrib",
+ output_package = "tensorflow"):
+ """Creates API directory structure and __init__.py files.
+
+ Creates a genrule that generates a directory structure with __init__.py
+ files that import all exported modules (i.e. modules with tf_export
+ decorators).
+
+ Args:
+ name: name of genrule to create.
+ output_files: List of __init__.py files that should be generated.
+      This list should include a file name for every module exported using
+      tf_export. For example, if an op is decorated with
+      @tf_export('module1.module2', 'module3'), then output_files should
+      include module1/module2/__init__.py and module3/__init__.py.
+ root_init_template: Python init file that should be used as template for
+ root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
+ template will be replaced with root imports collected by this genrule.
+ srcs: genrule sources. If passing root_init_template, the template file
+ must be included in sources.
+ api_name: Name of the project that you want to generate API files for
+ (e.g. "tensorflow" or "estimator").
+ api_version: TensorFlow API version to generate. Must be either 1 or 2.
+ package: Python package containing the @tf_export decorators you want to
+        process.
+ package_dep: Python library target containing your package.
+ """
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
@@ -139,13 +144,13 @@ def gen_api_init_files(
api_gen_binary_target = "create_" + package + "_api"
native.py_binary(
name = "create_" + package + "_api",
- srcs = ["//tensorflow/tools/api/generator:create_python_api.py"],
- main = "//tensorflow/tools/api/generator:create_python_api.py",
+ srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
+ main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
package_dep,
- "//tensorflow/tools/api/generator:doc_srcs",
+ "//tensorflow/python/tools/api/generator:doc_srcs",
],
)
@@ -154,7 +159,9 @@ def gen_api_init_files(
outs = output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
+ root_init_template_flag + " --apidir=$(@D) --apiname=" +
+ api_name + " --apiversion=" + str(api_version) + " --package=" + package +
+ " --output_package=" + output_package + " $(OUTS)"),
srcs = srcs,
tools = [":" + api_gen_binary_target ],
visibility = ["//tensorflow:__pkg__"],
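A hypothetical BUILD invocation of the relocated macro (target and template names are illustrative, not taken from this change):

    load("//tensorflow/python/tools/api/generator:api_gen.bzl",
         "gen_api_init_files")

    gen_api_init_files(
        name = "python_api_gen",
        api_version = 2,
        root_init_template = "api_template.__init__.py",
        srcs = ["api_template.__init__.py"],
        output_package = "tensorflow",
    )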
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
index 48d7dcd09e..863c922216 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -24,11 +24,12 @@ import importlib
import os
import sys
+from tensorflow.python.tools.api.generator import doc_srcs
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
-from tensorflow.tools.api.generator import doc_srcs
API_ATTRS = tf_export.API_ATTRS
+API_ATTRS_V1 = tf_export.API_ATTRS_V1
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
@@ -38,14 +39,14 @@ _SYMBOLS_TO_SKIP_EXPLICITLY = {
'tensorflow.python.platform.flags.FLAGS'
}
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
-# Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"
from __future__ import print_function
"""
-_GENERATED_FILE_FOOTER = "\n\ndel print_function\n"
+_GENERATED_FILE_FOOTER = '\n\ndel print_function\n'
class SymbolExposedTwiceError(Exception):
@@ -159,13 +160,16 @@ __all__.remove('print_function')
return module_text_map
-def get_api_init_text(package, api_name):
+def get_api_init_text(package, output_package, api_name, api_version):
"""Get a map from destination module to __init__.py code for that module.
Args:
package: Base python package containing python with target tf_export
decorators.
+ output_package: Base output python package where generated API will
+ be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+    api_version: API version you want to generate (1 or 2).
Returns:
A dictionary where
@@ -173,6 +177,12 @@ def get_api_init_text(package, api_name):
value: (string) text that should be in __init__.py files for
corresponding modules.
"""
+ if api_version == 1:
+ names_attr = API_ATTRS_V1[api_name].names
+ constants_attr = API_ATTRS_V1[api_name].constants
+ else:
+ names_attr = API_ATTRS[api_name].names
+ constants_attr = API_ATTRS[api_name].constants
module_code_builder = _ModuleInitCodeBuilder()
# Traverse over everything imported above. Specifically,
@@ -193,7 +203,7 @@ def get_api_init_text(package, api_name):
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
- if module_contents_name == API_ATTRS[api_name].constants:
+ if module_contents_name == constants_attr:
for exports, value in attr:
for export in exports:
names = export.split('.')
@@ -205,9 +215,8 @@ def get_api_init_text(package, api_name):
_, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
- if (hasattr(attr, '__dict__') and
- API_ATTRS[api_name].names in attr.__dict__):
- for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access
+ if (hasattr(attr, '__dict__') and names_attr in attr.__dict__):
+ for export in getattr(attr, names_attr): # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
module_code_builder.add_import(
@@ -218,7 +227,6 @@ def get_api_init_text(package, api_name):
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(module_code_builder.module_imports.keys())
- import_from = '.'
for module in imported_modules:
if not module:
continue
@@ -229,6 +237,9 @@ def get_api_init_text(package, api_name):
if submodule_index > 0:
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
+ import_from = output_package
+ if submodule_index > 0:
+ import_from += '.' + '.'.join(module_split[:submodule_index])
module_code_builder.add_import(
-1, parent_module, import_from,
module_split[submodule_index], module_split[submodule_index])
@@ -294,7 +305,8 @@ def get_module_docstring(module_name, package, api_name):
def create_api_files(
- output_files, package, root_init_template, output_dir, api_name):
+ output_files, package, root_init_template, output_dir, output_package,
+ api_name, api_version):
"""Creates __init__.py files for the Python API.
Args:
@@ -306,7 +318,9 @@ def create_api_files(
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
+ output_package: Base output package where generated API will be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+    api_version: API version to generate (1 or 2).
Raises:
ValueError: if an output file is not under api/ directory,
@@ -323,7 +337,8 @@ def create_api_files(
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(package, api_name)
+ module_text_map = get_api_init_text(
+ package, output_package, api_name, api_version)
# Add imports to output files.
missing_output_files = []
@@ -381,6 +396,13 @@ def main():
'--apiname', required=True, type=str,
choices=API_ATTRS.keys(),
help='The API you want to generate.')
+ parser.add_argument(
+ '--apiversion', default=2, type=int,
+ choices=[1, 2],
+ help='The API version you want to generate.')
+ parser.add_argument(
+ '--output_package', default='tensorflow', type=str,
+ help='Root output package.')
args = parser.parse_args()
@@ -395,7 +417,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
create_api_files(outputs, args.package, args.root_init_template,
- args.apidir, args.apiname)
+ args.apidir, args.output_package, args.apiname,
+ args.apiversion)
if __name__ == '__main__':
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py
index 651ec9d040..a565a49d96 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/python/tools/api/generator/create_python_api_test.py
@@ -22,8 +22,8 @@ import imp
import sys
from tensorflow.python.platform import test
+from tensorflow.python.tools.api.generator import create_python_api
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.tools.api.generator import create_python_api
@tf_export('test_op', 'test_op1')
@@ -58,7 +58,8 @@ class CreatePythonApiTest(test.TestCase):
def testFunctionImportIsAdded(self):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
- api_name='tensorflow')
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=1)
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
@@ -75,7 +76,8 @@ class CreatePythonApiTest(test.TestCase):
def testClassImportIsAdded(self):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
- api_name='tensorflow')
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=2)
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
self.assertTrue(
@@ -85,7 +87,8 @@ class CreatePythonApiTest(test.TestCase):
def testConstantIsAdded(self):
imports = create_python_api.get_api_init_text(
package=create_python_api._DEFAULT_PACKAGE,
- api_name='tensorflow')
+ output_package='tensorflow',
+ api_name='tensorflow', api_version=1)
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),
diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
index ad1988494d..ad1988494d 100644
--- a/tensorflow/tools/api/generator/doc_srcs.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/python/tools/api/generator/doc_srcs_test.py
index dbff904abe..481d9874a4 100644
--- a/tensorflow/tools/api/generator/doc_srcs_test.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Tests for tensorflow.tools.api.generator.doc_srcs."""
+"""Tests for tensorflow.python.tools.api.generator.doc_srcs."""
from __future__ import absolute_import
from __future__ import division
@@ -23,7 +23,7 @@ import importlib
import sys
from tensorflow.python.platform import test
-from tensorflow.tools.api.generator import doc_srcs
+from tensorflow.python.tools.api.generator import doc_srcs
FLAGS = None
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index d33fd7376a..c719045c7f 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -614,48 +614,6 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, aggregation):
- """Inside this scope, new variables will not be mirrored.
-
- There will still be one component variable per tower, but there is
- no requirement that they stay in sync. Instead, when saving them
- or calling `read_var()`, we use the value that results when
- calling `reduce()` on all the towers' variables.
-
- Note: tower-local implies not trainable. Instead, it is expected
- that each tower will directly update (using `assign_add()` or
- whatever) its local variable instance but only the aggregated
- value (accessible using `read_var()`) will be exported from the
- model. When it is acceptable to only aggregate on export, we
- greatly reduce communication overhead by using tower-local
- variables.
-
- Note: All component variables will be initialized to the same
- value, using the initialization expression from the first tower.
- The values will match even if the initialization expression uses
- random numbers.
-
- Args:
- aggregation: Indicates how a variable will be aggregated. Accepted values
- are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
-
- Returns:
- A context manager.
- """
- # TODO(psv): Remove this after adding support for synchronization and
- # aggregation parameters in get_variable() and mirrored strategy.
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["use_resource"] = True
-
- # Set synchronization to be ON_READ for tower local variables.
- kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ
- kwargs["aggregation"] = aggregation
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def read_var(self, v):
"""Reads the value of a variable.
@@ -1103,10 +1061,6 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, aggregation):
- """Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(aggregation)
-
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
@@ -1158,16 +1112,6 @@ class _DefaultDistributionStrategy(DistributionStrategy):
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, aggregation):
- """Does not set to resource variables."""
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["trainable"] = False
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
_require_distribution_strategy_scope(self)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 971ed5c8b5..f75db08059 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -77,9 +77,10 @@ def _deduplicate_indexed_slices(values, indices):
def _var_key(var):
- if context.executing_eagerly():
- return var._unique_id # pylint: disable=protected-access
- return (var.op.graph, var.op.name)
+ # TODO(ashankar): Consolidate handling for eager and graph
+ if hasattr(var, "op"):
+ return (var.op.graph, var.op.name)
+ return var._unique_id # pylint: disable=protected-access
class _OptimizableVariable(object):
diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i
index fb5e47efa0..54d6789616 100644
--- a/tensorflow/python/training/quantize_training.i
+++ b/tensorflow/python/training/quantize_training.i
@@ -73,6 +73,8 @@ def do_quantize_training_on_graphdef(input_graph, num_bits):
do_quantize_training_on_graphdef._tf_api_names = [
'train.do_quantize_training_on_graphdef']
+do_quantize_training_on_graphdef._tf_api_names_v1 = [
+ 'train.do_quantize_training_on_graphdef']
%}
%unignoreall
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 376be39978..c8ed2b715d 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -87,6 +87,27 @@ def _call_location(outer=False):
return '%s:%d' % (entry[1], entry[2])
+def _wrap_decorator(wrapped_function):
+ """Indicate that one function wraps another.
+
+  This decorator wraps a function using `tf_decorator.make_decorator`
+  so that doc generation scripts can pick up the original function
+  signature.
+  It would be better to use the @functools.wraps decorator, but it would
+  not update the function signature to match the wrapped function in Python 2.
+
+ Args:
+    wrapped_function: The function that the decorated function wraps.
+
+ Returns:
+ Function that accepts wrapper function as an argument and returns
+ `TFDecorator` instance.
+ """
+ def wrapper(wrapper_func):
+ return tf_decorator.make_decorator(wrapped_function, wrapper_func)
+ return wrapper
+
+
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
"""Deprecate a symbol in favor of a new name with identical semantics.
@@ -144,7 +165,7 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
if tf_inspect.isclass(func_or_class):
# Make a new class with __init__ wrapped in a warning.
- class NewClass(func_or_class): # pylint: disable=missing-docstring
+ class _NewClass(func_or_class): # pylint: disable=missing-docstring
__doc__ = decorator_utils.add_notice_to_docstring(
func_or_class.__doc__, 'Please use %s instead.' % name,
'DEPRECATED CLASS',
@@ -153,27 +174,28 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
__name__ = func_or_class.__name__
__module__ = _call_location(outer=True)
+ @_wrap_decorator(func_or_class.__init__)
def __init__(self, *args, **kwargs):
- if hasattr(NewClass.__init__, '__func__'):
+ if hasattr(_NewClass.__init__, '__func__'):
# Python 2
- NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
else:
# Python 3
- NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
if _PRINT_DEPRECATION_WARNINGS:
# We're making the alias as we speak. The original may have other
# aliases, so we cannot use it to check for whether it's already been
# warned about.
- if NewClass.__init__ not in _PRINTED_WARNING:
+ if _NewClass.__init__ not in _PRINTED_WARNING:
if warn_once:
- _PRINTED_WARNING[NewClass.__init__] = True
+ _PRINTED_WARNING[_NewClass.__init__] = True
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), deprecated_name, name)
- super(NewClass, self).__init__(*args, **kwargs)
+ super(_NewClass, self).__init__(*args, **kwargs)
- return NewClass
+ return _NewClass
else:
decorator_utils.validate_callable(func_or_class, 'deprecated')
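A minimal sketch of why _wrap_decorator is used here: tf_decorator.make_decorator records the wrapped target, and tf_inspect unwraps it, so the original signature survives even on Python 2 (function names below are illustrative):

    from tensorflow.python.util import tf_decorator
    from tensorflow.python.util import tf_inspect

    def target(a, b=2):
      return a + b

    def wrapper(*args, **kwargs):
      return target(*args, **kwargs)

    wrapped = tf_decorator.make_decorator(target, wrapper)
    # tf_inspect sees through the TFDecorator to the original signature.
    assert (tf_inspect.getfullargspec(wrapped) ==
            tf_inspect.getfullargspec(target))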
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index bdd0bc48d2..1ea695e4d6 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_inspect
class DeprecatedAliasTest(test.TestCase):
@@ -73,6 +74,11 @@ class DeprecatedAliasTest(test.TestCase):
self.assertEqual(["test", "deprecated", "deprecated again"],
MyClass.init_args)
+ # Check __init__ signature matches for doc generation.
+ self.assertEqual(
+ tf_inspect.getfullargspec(MyClass.__init__),
+ tf_inspect.getfullargspec(deprecated_cls.__init__))
+
class DeprecationTest(test.TestCase):
diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i
index 8004898cbc..1c73f7f06f 100644
--- a/tensorflow/python/util/py_checkpoint_reader.i
+++ b/tensorflow/python/util/py_checkpoint_reader.i
@@ -166,6 +166,7 @@ def NewCheckpointReader(filepattern):
return CheckpointReader(compat.as_bytes(filepattern), status)
NewCheckpointReader._tf_api_names = ['train.NewCheckpointReader']
+NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
%}
%include "tensorflow/c/checkpoint_reader.h"
diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i
index 73fa85494b..a5a7984d91 100644
--- a/tensorflow/python/util/stat_summarizer.i
+++ b/tensorflow/python/util/stat_summarizer.i
@@ -27,8 +27,8 @@ limitations under the License.
%ignoreall
-%unignore _NewStatSummarizer;
-%unignore _DeleteStatSummarizer;
+%unignore NewStatSummarizer;
+%unignore DeleteStatSummarizer;
%unignore tensorflow;
%unignore tensorflow::StatSummarizer;
%unignore tensorflow::StatSummarizer::StatSummarizer;
@@ -43,20 +43,20 @@ limitations under the License.
// TODO(ashankar): Remove the unused argument from the API.
%{
-tensorflow::StatSummarizer* _NewStatSummarizer(
+tensorflow::StatSummarizer* NewStatSummarizer(
const string& unused) {
return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions());
}
%}
%{
-void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
+void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
delete ss;
}
%}
-tensorflow::StatSummarizer* _NewStatSummarizer(const string& unused);
-void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
+tensorflow::StatSummarizer* NewStatSummarizer(const string& unused);
+void DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
%extend tensorflow::StatSummarizer {
void ProcessStepStatsStr(const string& step_stats_str) {
@@ -76,16 +76,3 @@ void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
%include "tensorflow/core/util/stat_summarizer_options.h"
%include "tensorflow/core/util/stat_summarizer.h"
%unignoreall
-
-%insert("python") %{
-
-# Wrapping NewStatSummarizer and DeletStatSummarizer because
-# SWIG-generated functions are built-in functions and do not support
-# setting _tf_api_names attribute.
-
-def NewStatSummarizer(unused):
- return _NewStatSummarizer(unused)
-
-def DeleteStatSummarizer(stat_summarizer):
- _DeleteStatSummarizer(stat_summarizer)
-%}
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index e154ffb68a..c362d588ab 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -63,6 +63,15 @@ API_ATTRS = {
'_estimator_api_constants')
}
+API_ATTRS_V1 = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names_v1',
+ '_tf_api_constants_v1'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names_v1',
+ '_estimator_api_constants_v1')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
@@ -78,13 +87,16 @@ class api_export(object): # pylint: disable=invalid-name
Args:
*args: API names in dot delimited format.
**kwargs: Optional keyed arguments.
- overrides: List of symbols that this is overriding
+ v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
+ names both for TensorFlow V1 and V2 APIs.
+ overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
`estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
@@ -102,24 +114,27 @@ class api_export(object): # pylint: disable=invalid-name
and kwarg `allow_multiple_exports` not set.
"""
api_names_attr = API_ATTRS[self._api_name].names
-
+ api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
delattr(undecorated_f, api_names_attr)
+ delattr(undecorated_f, api_names_attr_v1)
_, undecorated_func = tf_decorator.unwrap(func)
+ self.set_attr(undecorated_func, api_names_attr, self._names)
+ self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
+ return func
+ def set_attr(self, func, api_names_attr, names):
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if api_names_attr in undecorated_func.__dict__:
+ if api_names_attr in func.__dict__:
raise SymbolAlreadyExposedError(
'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, getattr(
- undecorated_func, api_names_attr))) # pylint: disable=protected-access
- setattr(undecorated_func, api_names_attr, self._names)
- return func
+ (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
+ setattr(func, api_names_attr, names)
def export_constant(self, module_name, name):
"""Store export information for constants/string literals.
@@ -140,12 +155,20 @@ class api_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, API_ATTRS[self._api_name].constants):
- setattr(module, API_ATTRS[self._api_name].constants, [])
+ api_constants_attr = API_ATTRS[self._api_name].constants
+ api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants
+
+ if not hasattr(module, api_constants_attr):
+ setattr(module, api_constants_attr, [])
# pylint: disable=protected-access
- getattr(module, API_ATTRS[self._api_name].constants).append(
+ getattr(module, api_constants_attr).append(
(self._names, name))
+ if not hasattr(module, api_constants_attr_v1):
+ setattr(module, api_constants_attr_v1, [])
+ getattr(module, api_constants_attr_v1).append(
+ (self._names_v1, name))
+
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
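With the v1 keyword added above, a symbol can keep a legacy name in the 1.x API while exposing only a new name in 2.x. A minimal sketch (the exported names are illustrative):

    from tensorflow.python.util.tf_export import tf_export

    @tf_export('math.erf', v1=['math.erf', 'erf'])
    def erf(x):
      pass

This sets _tf_api_names to ('math.erf',) and _tf_api_names_v1 to ['math.erf', 'erf'] on the undecorated function, and the generator then picks one or the other based on --apiversion.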
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
index b9e26ecb33..4ae1dc55e0 100644
--- a/tensorflow/python/util/tf_export_test.py
+++ b/tensorflow/python/util/tf_export_test.py
@@ -60,6 +60,8 @@ class ValidateExportTest(test.TestCase):
for symbol in [_test_function, _test_function, TestClassA, TestClassB]:
if hasattr(symbol, '_tf_api_names'):
del symbol._tf_api_names
+ if hasattr(symbol, '_tf_api_names_v1'):
+ del symbol._tf_api_names_v1
def _CreateMockModule(self, name):
mock_module = self.MockModule(name)
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index fbd6561767..ec20998bdd 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -300,6 +300,16 @@ def getsource(object): # pylint: disable=redefined-builtin
return _inspect.getsource(tf_decorator.unwrap(object)[1])
+def getsourcefile(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.getsourcefile."""
+ return _inspect.getsourcefile(tf_decorator.unwrap(object)[1])
+
+
+def getsourcelines(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.getsourcelines."""
+ return _inspect.getsourcelines(tf_decorator.unwrap(object)[1])
+
+
def isbuiltin(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.isbuiltin."""
return _inspect.isbuiltin(tf_decorator.unwrap(object)[1])
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index beaf350de1..2f6021c7d8 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -326,6 +326,18 @@ def test_decorated_function_with_defaults(a, b=2, c='Hello'):
self.assertEqual(
expected, tf_inspect.getsource(test_decorated_function_with_defaults))
+ def testGetSourceFile(self):
+ self.assertEqual(
+ __file__,
+ tf_inspect.getsourcefile(test_decorated_function_with_defaults))
+
+ def testGetSourceLines(self):
+ expected = inspect.getsourcelines(
+ test_decorated_function_with_defaults.decorated_target)
+ self.assertEqual(
+ expected,
+ tf_inspect.getsourcelines(test_decorated_function_with_defaults))
+
def testIsBuiltin(self):
self.assertEqual(
tf_inspect.isbuiltin(TestDecoratedClass),
diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py
new file mode 100644
index 0000000000..dacc1ce83e
--- /dev/null
+++ b/tensorflow/python/util/tf_stack.py
@@ -0,0 +1,97 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions used to extract and analyze stacks. Faster than Python libs."""
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import linecache
+import sys
+
+
+def extract_stack(extract_frame_info_fn=None):
+ """A lightweight, extensible re-implementation of traceback.extract_stack.
+
+ NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
+ each stack frame using linecache, which results in an abundance of stat()
+ calls. This implementation does not retrieve the code, and any consumer
+  should apply convert_stack to the result to obtain a traceback that can
+ be formatted etc. using traceback methods.
+
+ Args:
+ extract_frame_info_fn: Optional callable fn(stack_frame) applied to each
+ stack frame. This callable's return value is stored as the sixth (last)
+ element of the returned tuples. If not provided, the returned tuples
+ will have None as their sixth value.
+
+ Returns:
+ A list of 6-tuples
+ (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
+ corresponding to the call stack of the current thread. The returned tuples
+ have the innermost stack frame at the end, unlike the Python inspect
+ module's stack() function.
+ """
+ default_fn = lambda f: None
+ extract_frame_info_fn = extract_frame_info_fn or default_fn
+  # Raising and immediately catching an exception is a cheap way to reach the
+  # caller's frame via sys.exc_info(), without calling inspect.currentframe().
+  try:
+    raise ZeroDivisionError
+  except ZeroDivisionError:
+    f = sys.exc_info()[2].tb_frame.f_back
+ ret = []
+ while f is not None:
+ lineno = f.f_lineno
+ co = f.f_code
+ filename = co.co_filename
+ name = co.co_name
+ frame_globals = f.f_globals
+ func_start_lineno = co.co_firstlineno
+ frame_info = extract_frame_info_fn(f)
+ ret.append((filename, lineno, name, frame_globals, func_start_lineno,
+ frame_info))
+ f = f.f_back
+ ret.reverse()
+ return ret
+
+
+def convert_stack(stack, include_func_start_lineno=False):
+ """Converts a stack extracted using extract_stack() to a traceback stack.
+
+ Args:
+ stack: A list of n 6-tuples,
+ (filename, lineno, name, frame_globals, func_start_lineno, frame_info).
+ include_func_start_lineno: True if function start line number should be
+ included as the 5th entry in return tuples.
+
+ Returns:
+ A list of n 4-tuples or 5-tuples
+ (filename, lineno, name, code, [optional: func_start_lineno]), where the
+ code tuple element is calculated from the corresponding elements of the
+ input tuple.
+ """
+ ret = []
+ for (filename, lineno, name, frame_globals, func_start_lineno,
+ unused_frame_info) in stack:
+ linecache.checkcache(filename)
+ line = linecache.getline(filename, lineno, frame_globals)
+ if line:
+ line = line.strip()
+ else:
+ line = None
+ if include_func_start_lineno:
+ ret.append((filename, lineno, name, line, func_start_lineno))
+ else:
+ ret.append((filename, lineno, name, line))
+ return ret
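As a usage sketch of the two helpers just added (hypothetical caller code, not part of this change): extract_stack() records raw frame data with no linecache lookups, and convert_stack() pays the source-line cost only when a traceback is actually formatted.

    import traceback
    from tensorflow.python.util import tf_stack

    # Cheap capture: one 6-tuple per frame, no stat() calls.
    stack = tf_stack.extract_stack()

    # Deferred, expensive step: attach source lines via linecache, then
    # format with the standard traceback machinery.
    frames = tf_stack.convert_stack(stack)   # 4-tuples accepted by traceback
    print(''.join(traceback.format_list(frames)))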
diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md
index bb97543a21..1966789c84 100644
--- a/tensorflow/security/advisory/tfsa-2018-001.md
+++ b/tensorflow/security/advisory/tfsa-2018-001.md
@@ -22,7 +22,7 @@ TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0
### Mitigation
We have patched the vulnerability in GitHub commit
-[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55).
+[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae433).
If users are running TensorFlow in production or on untrusted data, they are
encouraged to apply this patch.
diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md
index ea39e17ab2..0f176151c2 100644
--- a/tensorflow/security/index.md
+++ b/tensorflow/security/index.md
@@ -4,7 +4,7 @@ We regularly publish security advisories about using TensorFlow.
*Note*: In conjunction with these security advisories, we strongly encourage
TensorFlow users to read and understand TensorFlow's security model as outlined
-in (https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)[SECURITY.md].
+in [SECURITY.md](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).
| Advisory Number | Type | Versions affected | Reported by | Additional Information |
|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc
index 50a6edd80b..52efe771bc 100644
--- a/tensorflow/stream_executor/event.cc
+++ b/tensorflow/stream_executor/event.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
-#include "tensorflow/stream_executor/stream.h"
namespace stream_executor {
@@ -27,9 +27,12 @@ Event::Event(StreamExecutor* stream_exec)
stream_exec_->implementation()->CreateEventImplementation()) {}
Event::~Event() {
- auto status = stream_exec_->DeallocateEvent(this);
- if (!status.ok()) {
- LOG(ERROR) << status.error_message();
+ // Deal with nullptr implementation_, as this event may have been std::move()'d.
+ if (stream_exec_ && implementation_) {
+ auto status = stream_exec_->DeallocateEvent(this);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
}
}
diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h
index 1f37262c78..9cc87a7c12 100644
--- a/tensorflow/stream_executor/event.h
+++ b/tensorflow/stream_executor/event.h
@@ -61,6 +61,9 @@ class Event {
// Returns a pointer to the underlying platform-specific implementation.
internal::EventInterface* implementation() { return implementation_.get(); }
+ Event(Event&&) = default;
+ Event& operator=(Event&&) = default;
+
private:
friend class Stream;
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index c8a6297330..3cd97b3cf1 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -26,8 +26,6 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/plugin_registry.h"
-bool FLAGS_stream_executor_cpu_real_clock_rate = false;
-
namespace stream_executor {
namespace host {
@@ -95,7 +93,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
// the nature of the HostExecutor) memcpy on the stream (HostStream)
// associated with the HostExecutor.
AsHostStream(stream)->EnqueueTask(
- [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); });
+ [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); });
return true;
}
@@ -190,11 +188,8 @@ DeviceDescription *HostExecutor::PopulateDeviceDescription() const {
// doesn't result in thrashing or other badness? 4GiB chosen arbitrarily.
builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
- float cycle_counter_frequency = 1e9;
- if (FLAGS_stream_executor_cpu_real_clock_rate) {
- cycle_counter_frequency = static_cast<float>(
- tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
- }
+ float cycle_counter_frequency = static_cast<float>(
+ tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);
auto built = builder.Build();
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 0cd0790a72..9369183133 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5228,24 +5228,11 @@ port::Status Stream::BlockHostUntilDone() {
return status;
}
- port::Status first_error;
- {
- // Wait until all active sub-streams have done their tasks.
- mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (!stream.second) {
- first_error.Update(stream.first->BlockHostUntilDone());
- // Set this sub-stream as available.
- stream.second = true;
- }
- }
- }
-
temporary_memory_manager_.DeallocateFinalizedTemporaries();
- first_error.Update(parent_->BlockHostUntilDone(this));
- CheckError(first_error.ok());
- return first_error;
+ port::Status error = parent_->BlockHostUntilDone(this);
+ CheckError(error.ok());
+ return error;
}
} // namespace stream_executor
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 3e3fbeb8f8..9259ebe869 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -28,7 +28,6 @@ load(
"//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
)
-
def register_extension_info(**kwargs):
pass
@@ -830,6 +829,9 @@ def tf_cc_test_mkl(srcs,
tags=[],
size="medium",
args=None):
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
for src in srcs:
native.cc_test(
name=src_to_test_name(src),
@@ -855,6 +857,7 @@ def tf_cc_test_mkl(srcs,
tags=tags,
size=size,
args=args,
+ features=disable_header_modules,
nocopts="-fno-exceptions")
@@ -989,16 +992,17 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
-def tf_kernel_library(name,
- prefix=None,
- srcs=None,
- gpu_srcs=None,
- hdrs=None,
- deps=None,
- alwayslink=1,
- copts=None,
- is_external=False,
- **kwargs):
+def tf_kernel_library(
+ name,
+ prefix = None,
+ srcs = None,
+ gpu_srcs = None,
+ hdrs = None,
+ deps = None,
+ alwayslink = 1,
+ copts = None,
+ is_external = False,
+ **kwargs):
"""A rule to build a TensorFlow OpKernel.
May either specify srcs/hdrs or prefix. Similar to tf_cuda_library,
@@ -1028,6 +1032,7 @@ def tf_kernel_library(name,
deps = []
if not copts:
copts = []
+ textual_hdrs = []
copts = copts + tf_copts(is_external=is_external)
if prefix:
if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]):
@@ -1038,8 +1043,13 @@ def tf_kernel_library(name,
srcs = srcs + native.glob(
[prefix + "*.cc"], exclude=[prefix + "*test*", prefix + "*.cu.cc"])
hdrs = hdrs + native.glob(
- [prefix + "*.h"], exclude=[prefix + "*test*", prefix + "*.cu.h"])
-
+ [prefix + "*.h"],
+ exclude = [prefix + "*test*", prefix + "*.cu.h", prefix + "*impl.h"],
+ )
+ textual_hdrs = native.glob(
+ [prefix + "*impl.h"],
+ exclude = [prefix + "*test*", prefix + "*.cu.h"],
+ )
cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")]
if gpu_srcs:
for gpu_src in gpu_srcs:
@@ -1053,6 +1063,7 @@ def tf_kernel_library(name,
name=name,
srcs=srcs,
hdrs=hdrs,
+ textual_hdrs = textual_hdrs,
copts=copts,
cuda_deps=cuda_deps,
linkstatic=1, # Needed since alwayslink is broken in bazel b/27630669
@@ -1086,6 +1097,9 @@ def tf_mkl_kernel_library(name,
hdrs = hdrs + native.glob(
[prefix + "*.h"])
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
native.cc_library(
name=name,
srcs=if_mkl(srcs),
@@ -1093,7 +1107,8 @@ def tf_mkl_kernel_library(name,
deps=deps,
alwayslink=alwayslink,
copts=copts,
- nocopts=nocopts
+ nocopts=nocopts,
+ features = disable_header_modules
)
register_extension_info(
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
index ec1f72453f..c13eb7b8bb 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
@@ -56,7 +56,7 @@ tf_class {
}
member_method {
name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "global_variables"
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index e89b4dbffd..6ec3aba775 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -121,6 +121,10 @@ tf_module {
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
+ name: "non_max_suppression_overlaps"
+ argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
+ }
+ member_method {
name: "pad_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 11cdd6f0b5..40e82b18b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index 4afad3e4df..8295905975 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
index 14a667870d..8645e54302 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
@@ -90,11 +90,11 @@ tf_module {
}
member_method {
name: "glorot_normal"
- argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "glorot_uniform"
- argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "he_normal"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
index 2bf973debb..86e328888e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
index 03f20e72c2..b0ed545781 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
index 4b46b8d15a..42f98ed03d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
index d8a1c76fd0..000898a4be 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index 622926bc4b..380b49f99c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 82100d8e09..82db5e6137 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index 408061077c..b6ff688ec3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
index a3c8031104..b41290f8b0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index e2dfaca29f..88a033e61f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index 4f068d2066..c1b9b96044 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index b8c261a743..f59f7727a3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
index 4ccd6cace6..7d3744ed92 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 2790e5fd85..3fd4ccdab2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -107,7 +107,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
index b1326bd0e6..ba21b50be4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index e3ac3dbf28..46f9fa2bbb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -188,7 +188,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
index 1117a695a3..c3ad326589 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index b9de142142..fd9eb43066 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
index deb535e06e..40d61688f2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index 9a9a223fba..b8c227d725 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
index 1c59b0bdf6..095d35e574 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
index 30cf5489f4..8f99961198 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 0ec69508d5..96d522a016 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
index 4cd8928403..de2824dab4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 4b4912496d..1d563241d8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
index d0ad9cf567..c87e52c537 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
index 98cff95a7f..dccf5523e3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
index 2357498b46..7ac4116d92 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
index 3324cbff30..024f72705d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
index 6c81823654..4e0233331b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
index 487e04fd07..32d46ce8f3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
index 137e7cced4..858486c725 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index 7161665d25..f65d750926 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
index 24affa2481..2e71ef503d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
index 7ba19a4269..42533bcd21 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
index 503aa9162c..b5df169417 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
index 1737e590a2..0ea17919a9 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
index 021d024dc2..a33248bc00 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 65387008bf..4ba21a25cd 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 4f791acf05..a7a570418e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index abc30e54e0..763bc23113 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 20791bb448..3c50a3d7f2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 449a91d873..ac78bdafad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index bb361e1297..275282d9d2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index e564bf3216..0e31e6058b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 4cb9cc3ec8..aacd0b1791 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index 5ed52b88ae..c236548663 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index f4559d29d7..6b9c0290aa 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 64e2d061e2..0d7b2211e6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index 3372ad6453..d080ad6aed 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index 08a6860bcd..fcb0a109da 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 22c9eab64f..1d0e22abd0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index 74c405ba9b..653c9f547b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index 39f6f98193..cdbaf82cf6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
index 7b25e80b6b..230c5e9034 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 3619b8bfc4..511456e740 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 8ef3d71dd8..4a3492ebd6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
index ecbaa9ce2c..5d05cf689f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
index 9b90db1e5e..7efa29be77 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
@@ -97,7 +97,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 3c60eaab7f..0ca8e0b52c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 3dac1ff342..f754fa1da8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index 7f1b5db4d3..c9516b8f07 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
index b3e31000f3..850ecff974 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
index bbd9d1b0dc..7c69e31f9a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
index fe72beea80..fba42642d7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
index e9bf57b2b0..9c277411ea 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 0eecc58a2b..7c2f6ccc8a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index 96785a7d85..802178dba6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index 42c46cccb3..e870dfe9ad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
index ac816f68d4..c1337ce0cb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt
index 56e32e9d36..ed27a62765 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
index 9ae99563e9..b9f05cb3e5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 815f3bc2d1..336d9f76fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
index e704992b4a..46282217e0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
index b3a58fa11e..42cd7e87ee 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt
index f3a96ab895..c00fa79adf 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
index 78f464583b..9f094a877a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
index 222344fd04..2f519a2438 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 55fddf576c..6b93116ba0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index 96314ce498..fd17115e27 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index 88bdf99566..4b37a94478 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 6eeea7a8d1..5bdadca74a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 3050d46249..9dfda96fc8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index dda4c9358b..7b7684ccd2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
index cc6275158b..3b15407fca 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index 5eb7e75047..6d04415267 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 500cb8c14e..04950654d5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index 1113a7634f..c424e6dcc8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index c4b9f93561..1160d2840f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt
index 35ad87ad5d..740a03367b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index 282c98d79a..a08c583adb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
index acab93706b..c1294fed0f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -103,7 +103,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index a5ec228a07..dc401d3ed0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index d8d8e0bfe9..4b5165ae97 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 97d6dc06fb..789af15fea 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
index ea9bb41b99..0536a7cee7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index e6d1d2e089..8915353ec3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index f62017305f..6efb5ef15a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 07a1fde5bd..4c33c5d0bf 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index 62aa929d32..85f7c2bfed 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 93ecbbce9b..5211657414 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
index 11067058d5..c82e67526b 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
index 3259e706d7..1d031cb5f8 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
index e561f2f415..a8dda6655d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
index 3124a35c78..97f65ed894 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
index b5ec61255a..ccd9578f0d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
index b2c89ae66f..9cbb58d721 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
index 9e4f4969dc..c75ea3911e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
index 9850e6d765..5dc834e514 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
index be113826cc..96ab209874 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
index 0d951bf633..7e9656b352 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
index f1beeed9ef..e9a2269a6e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
index b75a012811..7d2eaaab2a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
index 80e0fb228b..8bc3eb26e9 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
@@ -106,7 +106,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
index 50ff484d73..6a0dcce56a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
index cea809744c..b6c84edf2a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
index ab9e89554c..062a02fa59 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
index 4362568445..eaad0fb23e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
index 3cad824cd3..ece28a8ce9 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index a8d9e120cb..c74773000a 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index c039890e1f..d251f54806 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 62c393de34..8a63b49180 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index f121ba7939..db1aae2757 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -120,7 +120,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 4583dc32b2..d76eab7eb8 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 5016b6ac30..944db6ac93 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 59623fc983..72b40cc9f7 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index e2ab5aaee9..a5c2b4aefd 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index bd2a6d61f8..61d5f04b22 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index bf2533e1b5..4f90743fec 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1174,7 +1174,7 @@ tf_module {
}
member_method {
name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_variable_scope"
@@ -1561,10 +1561,6 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
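
The argspec updates above capture synchronization and aggregation becoming explicit arguments, with trainable now defaulting to None so it can be inferred from the synchronization mode. A minimal sketch of a call using the new arguments (the variable name and initializer are illustrative, not taken from this change):

import tensorflow as tf

v = tf.get_variable(
    "counter",
    shape=[],
    initializer=tf.zeros_initializer(),
    # Distribution-aware arguments surfaced by this change:
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM,
    # ON_READ variables are non-trainable, which appears to be why the
    # trainable default moved from True to None (inferred when unset).
    trainable=False,
)
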
diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
index 1cf330e702..3a48cf683c 100644
--- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
+++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
@@ -88,6 +88,9 @@ def _SanitizedMRO(obj):
"""
return_list = []
for cls in tf_inspect.getmro(obj):
+ if cls.__name__ == '_NewClass':
+ # Ignore class created by @deprecated_alias decorator.
+ continue
str_repr = str(cls)
return_list.append(str_repr)
if 'tensorflow' not in str_repr:
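
For context, a class alias created by a decorator typically wraps the original in a subclass, so a wrapper like _NewClass appears in the MRO of every aliased class. A minimal sketch of the pattern being filtered here (a simplified stand-in, not the actual @deprecated_alias implementation):

import inspect

def deprecated_alias(cls):
  # Simplified: the real decorator would also emit a deprecation
  # warning when the alias is instantiated.
  class _NewClass(cls):
    pass
  return _NewClass

class Original(object):
  pass

Alias = deprecated_alias(Original)
print([c.__name__ for c in inspect.getmro(Alias)])
# ['_NewClass', 'Original', 'object'] -- hence the skip above.
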
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 90375a794f..d1b34fb242 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,13 @@ import sys
import unittest
import tensorflow as tf
+# pylint: disable=g-import-not-at-top
+try:
+ from tensorflow.compat import v1 as tf_v1
+ # We import compat.v1 as tf_v1 instead.
+ del tf.compat.v1
+except ImportError:
+ tf_v1 = None
from google.protobuf import message
from google.protobuf import text_format
@@ -46,6 +53,7 @@ from tensorflow.tools.api.lib import api_objects_pb2
from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
+# pylint: enable=g-import-not-at-top
# FLAGS defined at the bottom:
@@ -215,25 +223,19 @@ class ApiCompatibilityTest(test.TestCase):
visitor.do_not_descend_map['tf'].append('contrib')
traverse.traverse(tf, visitor)
- @unittest.skipUnless(
- sys.version_info.major == 2,
-      'API compatibility test goldens are generated using python2.')
- def testAPIBackwardsCompatibility(self):
- # Extract all API stuff.
+ def checkBackwardsCompatibility(self, root, golden_file_pattern):
+ # Extract all API stuff.
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
public_api_visitor = public_api.PublicAPIVisitor(visitor)
public_api_visitor.do_not_descend_map['tf'].append('contrib')
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
- traverse.traverse(tf, public_api_visitor)
+ traverse.traverse(root, public_api_visitor)
proto_dict = visitor.GetProtos()
# Read all golden files.
- expression = os.path.join(
- resource_loader.get_root_dir_with_all_resources(),
- _KeyToFilePath('*'))
- golden_file_list = file_io.get_matching_files(expression)
+ golden_file_list = file_io.get_matching_files(golden_file_pattern)
def _ReadFileToProto(filename):
"""Read a filename, create a protobuf from its contents."""
@@ -254,6 +256,26 @@ class ApiCompatibilityTest(test.TestCase):
verbose=FLAGS.verbose_diffs,
update_goldens=FLAGS.update_goldens)
+ @unittest.skipUnless(
+ sys.version_info.major == 2,
+      'API compatibility test goldens are generated using python2.')
+ def testAPIBackwardsCompatibility(self):
+ golden_file_pattern = os.path.join(
+ resource_loader.get_root_dir_with_all_resources(),
+ _KeyToFilePath('*'))
+ self.checkBackwardsCompatibility(tf, golden_file_pattern)
+
+ @unittest.skipUnless(
+ sys.version_info.major == 2,
+      'API compatibility test goldens are generated using python2.')
+ def testAPIBackwardsCompatibilityV1(self):
+ if not tf_v1:
+ return
+ golden_file_pattern = os.path.join(
+ resource_loader.get_root_dir_with_all_resources(),
+ _KeyToFilePath('*'))
+ self.checkBackwardsCompatibility(tf_v1, golden_file_pattern)
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index d0816c92b7..75da9bb835 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -35,6 +35,30 @@ elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then
exit 1
fi
+function is_absolute {
+ [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
+}
+
+RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST"
+function rlocation() {
+ if is_absolute "$1" ; then
+ # If the file path is already fully specified, simply return it.
+ echo "$1"
+ elif [[ -e "$TEST_SRCDIR/$1" ]]; then
+ # If the file exists in the $TEST_SRCDIR then just use it.
+ echo "$TEST_SRCDIR/$1"
+ elif [[ -e "$RUNFILES_MANIFEST_FILE" ]]; then
+ # If a runfiles manifest file exists then use it.
+ echo "$(grep "^$1 " "$RUNFILES_MANIFEST_FILE" | sed 's/[^ ]* //')"
+ fi
+}
+
+TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})"
+shift
+
+# Make sure /var/lock exists; this may not be true under MSYS.
+mkdir -p /var/lock
+
TF_GPU_COUNT=${TF_GPU_COUNT:-8}
for i in `seq 0 $((TF_GPU_COUNT-1))`; do
@@ -45,8 +69,8 @@ for i in `seq 0 $((TF_GPU_COUNT-1))`; do
# This export only works within the brackets, so it is isolated to one
# single command.
export CUDA_VISIBLE_DEVICES=$i
- echo "Running test $* on GPU $CUDA_VISIBLE_DEVICES"
- $@
+ echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
+    "$TEST_BINARY" "$@"
)
return_code=$?
flock -u "$lock_fd"
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 642dde36a7..30c318a58f 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -248,16 +248,6 @@ def update_md_files(old_version, new_version):
replace_string_in_line(r"<version>%s<\/version>" % old_version,
"<version>%s</version>" % new_version, filepath)
- # Update any links to colab notebooks.
- def colab_url(version):
- version_string = "%s.%s.%s" % (version.major, version.minor, version.patch)
- prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r"
- return prefix + version_string + "/"
-
- replace_string_in_line(
- colab_url(old_version), colab_url(new_version),
- "%s/docs_src/get_started/eager.md" % TF_SRC_DIR)
-
def major_minor_change(old_version, new_version):
"""Check if a major or minor change occurred."""
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index e10483e7fd..c03cbd9c66 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -23,10 +23,6 @@ function run_configure_for_gpu_build {
# Enable CUDA support
export TF_NEED_CUDA=1
-  # TODO(pcloudy): Remove this after TensorFlow uses its own CROSSTOOL
- # for GPU build on Windows
- export USE_MSVC_WRAPPER=1
-
yes "" | ./configure
}
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 8a237e4e28..3af132217e 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -54,10 +54,10 @@ export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0}
export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0}
export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7}
-export CUDA_INSTALL_PATH=${CUDA_INSTALL_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
+export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
# Add Cuda and Cudnn dll directories into PATH
-export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/bin:$PATH"
-export PATH="$(cygpath -u "${CUDA_INSTALL_PATH}")/extras/CUPTI/libx64:$PATH"
+export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/bin:$PATH"
+export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/extras/CUPTI/libx64:$PATH"
export PATH="$(cygpath -u "${CUDNN_INSTALL_PATH}")/bin:$PATH"
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index fe3bce428f..36b2142d95 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -105,14 +105,18 @@ create_python_test_dir "${PY_TEST_DIR}"
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
+TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result in testing the system-installed tensorflow
# GPU tests are very flaky when running concurrently, so set local_test_jobs=1
bazel test --announce_rc --config=opt -k --test_output=errors \
+ --test_env=TF_GPU_COUNT \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
--test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \
--build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \
- --local_test_jobs=1 --test_timeout="300,450,1200,3600" \
+ --local_test_jobs=$TF_GPU_COUNT --test_timeout="300,450,1200,3600" \
--flaky_test_attempts=3 \
//${PY_TEST_DIR}/tensorflow/python/... \
//${PY_TEST_DIR}/tensorflow/contrib/...
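
With --run_under, every test binary is launched through parallel_gpu_execute, which reads the TF_GPU_COUNT forwarded by --test_env and pins each test to a free GPU via a per-GPU lock, so --local_test_jobs can safely be raised from 1 to the GPU count without the concurrency flakiness noted in the comment above.
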
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
new file mode 100644
index 0000000000..23cc4a21a9
--- /dev/null
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -0,0 +1,502 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Upgrader for Python scripts according to an API change specification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import collections
+import os
+import shutil
+import sys
+import tempfile
+import traceback
+
+
+class APIChangeSpec(object):
+ """This class defines the transformations that need to happen.
+
+ This class must provide the following fields:
+
+ * `function_keyword_renames`: maps function names to a map of old -> new
+ argument names
+ * `function_renames`: maps function names to new function names
+ * `change_to_function`: a set of function names that have changed (for
+ notifications)
+ * `function_reorders`: maps functions whose argument order has changed to the
+ list of arguments in the new order
+ * `function_handle`: maps function names to custom handlers for the function
+
+ For an example, see `TFAPIChangeSpec`.
+ """
+
+
+class _FileEditTuple(
+ collections.namedtuple("_FileEditTuple",
+ ["comment", "line", "start", "old", "new"])):
+ """Each edit that is recorded by a _FileEditRecorder.
+
+ Fields:
+ comment: A description of the edit and why it was made.
+ line: The line number in the file where the edit occurs (1-indexed).
+    start: The column offset within the line where the edit begins (0-indexed).
+ old: text string to remove (this must match what was in file).
+ new: text string to add in place of `old`.
+ """
+
+ __slots__ = ()
+
+
+class _FileEditRecorder(object):
+ """Record changes that need to be done to the file."""
+
+ def __init__(self, filename):
+    # Edits are recorded per line; process() applies them to mutable char lists.
+ self._filename = filename
+
+ self._line_to_edit = collections.defaultdict(list)
+ self._errors = []
+
+ def process(self, text):
+ """Process a list of strings, each corresponding to the recorded changes.
+
+ Args:
+ text: A list of lines of text (assumed to contain newlines)
+ Returns:
+ A tuple of the modified text and a textual description of what is done.
+ Raises:
+ ValueError: if substitution source location does not have expected text.
+ """
+
+ change_report = ""
+
+    # Iterate over each line that has recorded edits
+ for line, edits in self._line_to_edit.items():
+ offset = 0
+      # Sort by column so that edits are applied left to right, keeping the
+      # cumulative offset adjustments correct when an edit changes the
+      # length of the string.
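+      # For example (hypothetical edits on the line "a.b(1, 2)"): replacing
+      # "a.b" (col 0, len 3) with "tf.c" (len 4) shifts later columns by +1,
+      # so an edit recorded at col 4 is applied at effective col 5.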
+ edits.sort(key=lambda x: x.start)
+
+      # Convert the line to a list of characters, since strings are
+      # immutable and cannot be edited in place.
+ char_array = list(text[line - 1])
+
+ # Record a description of the change
+ change_report += "%r Line %d\n" % (self._filename, line)
+ change_report += "-" * 80 + "\n\n"
+ for e in edits:
+ change_report += "%s\n" % e.comment
+ change_report += "\n Old: %s" % (text[line - 1])
+
+ # Make underscore buffers for underlining where in the line the edit was
+ change_list = [" "] * len(text[line - 1])
+ change_list_new = [" "] * len(text[line - 1])
+
+ # Iterate for each edit
+ for e in edits:
+ # Create effective start, end by accounting for change in length due
+ # to previous edits
+ start_eff = e.start + offset
+ end_eff = start_eff + len(e.old)
+
+ # Make sure the edit is changing what it should be changing
+ old_actual = "".join(char_array[start_eff:end_eff])
+        if old_actual != e.old:
+          raise ValueError("Expected text %r but got %r" %
+                           (e.old, old_actual))
+ # Make the edit
+ char_array[start_eff:end_eff] = list(e.new)
+
+ # Create the underline highlighting of the before and after
+ change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
+ change_list_new[start_eff:end_eff] = "~" * len(e.new)
+
+ # Keep track of how to generate effective ranges
+ offset += len(e.new) - len(e.old)
+
+ # Finish the report comment
+ change_report += " %s\n" % "".join(change_list)
+ text[line - 1] = "".join(char_array)
+ change_report += " New: %s" % (text[line - 1])
+ change_report += " %s\n\n" % "".join(change_list_new)
+ return "".join(text), change_report, self._errors
+
+ def add(self, comment, line, start, old, new, error=None):
+ """Add a new change that is needed.
+
+ Args:
+ comment: A description of what was changed
+ line: Line number (1 indexed)
+ start: Column offset (0 indexed)
+ old: old text
+ new: new text
+      error: If set, this edit cannot be made automatically; the message is
+        recorded in the error list.
+ Returns:
+ None
+ """
+
+ self._line_to_edit[line].append(
+ _FileEditTuple(comment, line, start, old, new))
+ if error:
+ self._errors.append("%s:%d: %s" % (self._filename, line, error))
+
+
+class _ASTCallVisitor(ast.NodeVisitor):
+ """AST Visitor that processes function calls.
+
+ Updates function calls from old API version to new API version using a given
+ change spec.
+ """
+
+ def __init__(self, filename, lines, api_change_spec):
+ self._filename = filename
+ self._file_edit = _FileEditRecorder(filename)
+ self._lines = lines
+ self._api_change_spec = api_change_spec
+
+ def process(self, lines):
+ return self._file_edit.process(lines)
+
+ def generic_visit(self, node):
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def _rename_functions(self, node, full_name):
+ function_renames = self._api_change_spec.function_renames
+ try:
+ new_name = function_renames[full_name]
+ self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
+ node.lineno, node.col_offset, full_name, new_name)
+ except KeyError:
+ pass
+
+ def _get_attribute_full_path(self, node):
+ """Traverse an attribute to generate a full name e.g. tf.foo.bar.
+
+ Args:
+ node: A Node of type Attribute.
+
+ Returns:
+ a '.'-delimited full-name or None if the tree was not a simple form.
+      i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
+ """
+ curr = node
+ items = []
+ while not isinstance(curr, ast.Name):
+ if not isinstance(curr, ast.Attribute):
+ return None
+ items.append(curr.attr)
+ curr = curr.value
+ items.append(curr.id)
+ return ".".join(reversed(items))
+
+ def _find_true_position(self, node):
+ """Return correct line number and column offset for a given node.
+
+ This is necessary mainly because ListComp's location reporting reports
+ the next token after the list comprehension list opening.
+
+ Args:
+ node: Node for which we wish to know the lineno and col_offset
+ """
+    find_open = re.compile(r"^\s*(\[).*$")
+    find_string_chars = re.compile(r"['\"]")
+
+ if isinstance(node, ast.ListComp):
+ # Strangely, ast.ListComp returns the col_offset of the first token
+ # after the '[' token which appears to be a bug. Workaround by
+ # explicitly finding the real start of the list comprehension.
+ line = node.lineno
+ col = node.col_offset
+ # loop over lines
+      while True:
+        # Reverse the preceding text and search it with a regular expression
+        # to find the opening "[" across whitespace.
+ text = self._lines[line - 1]
+ reversed_preceding_text = text[:col][::-1]
+ # First find if a [ can be found with only whitespace between it and
+ # col.
+ m = find_open.match(reversed_preceding_text)
+ if m:
+ new_col_offset = col - m.start(1) - 1
+ return line, new_col_offset
+ else:
+ if (reversed_preceding_text == "" or
+ reversed_preceding_text.isspace()):
+ line = line - 1
+ prev_line = self._lines[line - 1]
+ # TODO(aselle):
+ # this is poor comment detection, but it is good enough for
+ # cases where the comment does not contain string literal starting/
+ # ending characters. If ast gave us start and end locations of the
+ # ast nodes rather than just start, we could use string literal
+ # node ranges to filter out spurious #'s that appear in string
+ # literals.
+ comment_start = prev_line.find("#")
+ if comment_start == -1:
+ col = len(prev_line) - 1
+ elif find_string_chars.search(prev_line[comment_start:]) is None:
+ col = comment_start
+ else:
+ return None, None
+ else:
+ return None, None
+    # Most other node types report usable locations directly, so fall back
+    # to the node's own lineno and col_offset.
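+    # For example (hypothetical source): in "x = [i for i in y]" the parser
+    # reports the ListComp at the column of the first "i", not of "[";
+    # the scan above walks back to the "[" so edits anchor at the true start.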
+ return node.lineno, node.col_offset
+
+ def visit_Call(self, node): # pylint: disable=invalid-name
+ """Handle visiting a call node in the AST.
+
+ Args:
+ node: Current Node
+ """
+
+ # Find a simple attribute name path e.g. "tf.foo.bar"
+ full_name = self._get_attribute_full_path(node.func)
+
+ # Make sure the func is marked as being part of a call
+ node.func.is_function_for_call = True
+
+ if full_name:
+ # Call special handlers
+ function_handles = self._api_change_spec.function_handle
+ if full_name in function_handles:
+ function_handles[full_name](self._file_edit, node)
+
+ # Examine any non-keyword argument and make it into a keyword argument
+ # if reordering required.
+ function_reorders = self._api_change_spec.function_reorders
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
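+      # For example (hypothetical spec): with function_reorders
+      # {"tf.foo": ["x", "y"]}, the call tf.foo(1, 2) is rewritten to
+      # tf.foo(x=1, y=2) by inserting "x=" and "y=" before each argument.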
+
+ if full_name in function_reorders:
+ reordered = function_reorders[full_name]
+ for idx, arg in enumerate(node.args):
+ lineno, col_offset = self._find_true_position(arg)
+ if lineno is None or col_offset is None:
+ self._file_edit.add(
+ "Failed to add keyword %r to reordered function %r" %
+ (reordered[idx], full_name),
+ arg.lineno,
+ arg.col_offset,
+ "",
+ "",
+ error="A necessary keyword argument failed to be inserted.")
+ else:
+ keyword_arg = reordered[idx]
+ if (full_name in function_keyword_renames and
+ keyword_arg in function_keyword_renames[full_name]):
+ keyword_arg = function_keyword_renames[full_name][keyword_arg]
+ self._file_edit.add("Added keyword %r to reordered function %r" %
+ (reordered[idx], full_name), lineno, col_offset,
+ "", keyword_arg + "=")
+
+ # Examine each keyword argument and convert it to the final renamed form
+ renamed_keywords = ({} if full_name not in function_keyword_renames else
+ function_keyword_renames[full_name])
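+      # For example, a hypothetical rename {"dim": "axis"} turns
+      # tf.foo(dim=1) into tf.foo(axis=1), provided "dim=" is found
+      # immediately left of the value.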
+ for keyword in node.keywords:
+ argkey = keyword.arg
+ argval = keyword.value
+
+ if argkey in renamed_keywords:
+ argval_lineno, argval_col_offset = self._find_true_position(argval)
+ if argval_lineno is not None and argval_col_offset is not None:
+ # TODO(aselle): We should scan backward to find the start of the
+ # keyword key. Unfortunately ast does not give you the location of
+ # keyword keys, so we are forced to infer it from the keyword arg
+ # value.
+ key_start = argval_col_offset - len(argkey) - 1
+ key_end = key_start + len(argkey) + 1
+ if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
+ "="):
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
+ (argkey,
+ renamed_keywords[argkey]), argval_lineno,
+ argval_col_offset - len(argkey) - 1,
+ argkey + "=", renamed_keywords[argkey] + "=")
+ continue
+ self._file_edit.add(
+ "Failed to rename keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ "",
+ "",
+ error="Failed to find keyword lexographically. Fix manually.")
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def visit_Attribute(self, node): # pylint: disable=invalid-name
+ """Handle bare Attributes i.e. [tf.foo, tf.bar].
+
+ Args:
+ node: Node that is of type ast.Attribute
+ """
+ full_name = self._get_attribute_full_path(node)
+ if full_name:
+ self._rename_functions(node, full_name)
+ if full_name in self._api_change_spec.change_to_function:
+ if not hasattr(node, "is_function_for_call"):
+ new_text = full_name + "()"
+ self._file_edit.add("Changed %r to %r" % (full_name, new_text),
+ node.lineno, node.col_offset, full_name, new_text)
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+
+class ASTCodeUpgrader(object):
+ """Handles upgrading a set of Python files using a given API change spec."""
+
+ def __init__(self, api_change_spec):
+ if not isinstance(api_change_spec, APIChangeSpec):
+ raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
+ type(api_change_spec))
+ self._api_change_spec = api_change_spec
+
+ def process_file(self, in_filename, out_filename):
+ """Process the given python file for incompatible changes.
+
+ Args:
+ in_filename: filename to parse
+ out_filename: output file to write to
+ Returns:
+      A tuple of the number of files processed, a log of actions, and errors.
+ """
+
+    # Write to a temporary file, just in case we are doing an in-place modify.
+ with open(in_filename, "r") as in_file, \
+ tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
+ ret = self.process_opened_file(in_filename, in_file, out_filename,
+ temp_file)
+
+ shutil.move(temp_file.name, out_filename)
+ return ret
+
+ # Broad exceptions are required here because ast throws whatever it wants.
+ # pylint: disable=broad-except
+ def process_opened_file(self, in_filename, in_file, out_filename, out_file):
+ """Process the given python file for incompatible changes.
+
+ This function is split out to facilitate StringIO testing from
+ tf_upgrade_test.py.
+
+ Args:
+ in_filename: filename to parse
+ in_file: opened file (or StringIO)
+ out_filename: output file to write to
+ out_file: opened file (or StringIO)
+ Returns:
+      A tuple of the number of files processed, a log of actions, and errors.
+ """
+ process_errors = []
+ text = "-" * 80 + "\n"
+ text += "Processing file %r\n outputting to %r\n" % (in_filename,
+ out_filename)
+ text += "-" * 80 + "\n\n"
+
+ parsed_ast = None
+ lines = in_file.readlines()
+ try:
+ parsed_ast = ast.parse("".join(lines))
+ except Exception:
+ text += "Failed to parse %r\n\n" % in_filename
+ text += traceback.format_exc()
+ if parsed_ast:
+ visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
+ visitor.visit(parsed_ast)
+ out_text, new_text, process_errors = visitor.process(lines)
+ text += new_text
+ if out_file:
+ out_file.write(out_text)
+ text += "\n"
+ return 1, text, process_errors
+
+ # pylint: enable=broad-except
+
+ def process_tree(self, root_directory, output_root_directory,
+ copy_other_files):
+ """Processes upgrades on an entire tree of python files in place.
+
+ Note that only Python files. If you have custom code in other languages,
+ you will need to manually upgrade those.
+
+ Args:
+ root_directory: Directory to walk and process.
+      output_root_directory: Directory to use as the base of the output tree.
+ copy_other_files: Copy files that are not touched by this converter.
+
+ Returns:
+      A tuple of the number of files processed, the report string for all
+      files, and a list of errors.
+ """
+
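+    # Sketch of tree mode (paths hypothetical): process_tree("src/",
+    # "src_upgraded/", copy_other_files=True) mirrors src/ into
+    # src_upgraded/, rewriting .py files and copying everything else.
+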
+ # make sure output directory doesn't exist
+ if output_root_directory and os.path.exists(output_root_directory):
+ print("Output directory %r must not already exist." %
+ (output_root_directory))
+ sys.exit(1)
+
+    # make sure the output directory is not the same as the input directory
+ norm_root = os.path.split(os.path.normpath(root_directory))
+ norm_output = os.path.split(os.path.normpath(output_root_directory))
+ if norm_root == norm_output:
+ print("Output directory %r same as input directory %r" %
+ (root_directory, output_root_directory))
+ sys.exit(1)
+
+    # Collect list of files to process (we do this to correctly handle the
+    # case where the user puts the output directory inside the input dir)
+ files_to_process = []
+ files_to_copy = []
+ for dir_name, _, file_list in os.walk(root_directory):
+ py_files = [f for f in file_list if f.endswith(".py")]
+ copy_files = [f for f in file_list if not f.endswith(".py")]
+ for filename in py_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(output_root_directory,
+ os.path.relpath(fullpath,
+ root_directory))
+ files_to_process.append((fullpath, fullpath_output))
+ if copy_other_files:
+ for filename in copy_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(output_root_directory,
+ os.path.relpath(
+ fullpath, root_directory))
+ files_to_copy.append((fullpath, fullpath_output))
+
+ file_count = 0
+ tree_errors = []
+ report = ""
+ report += ("=" * 80) + "\n"
+ report += "Input tree: %r\n" % root_directory
+ report += ("=" * 80) + "\n"
+
+ for input_path, output_path in files_to_process:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ file_count += 1
+ _, l_report, l_errors = self.process_file(input_path, output_path)
+ tree_errors += l_errors
+ report += l_report
+ for input_path, output_path in files_to_copy:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ shutil.copy(input_path, output_path)
+ return file_count, report, tree_errors
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 57a491255e..fd94d64268 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -63,7 +63,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.11.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
new file mode 100644
index 0000000000..6796ad70e5
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl
@@ -0,0 +1,83 @@
+FROM tensorflow/tensorflow:latest-devel
+
+LABEL maintainer="Clayne Robison<clayne.b.robison@intel.com>"
+
+# These arguments are parameterized. Use --build-arg to override.
+ARG TF_BRANCH=r1.9
+ARG WHL_DIR=/whl
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ golang \
+ vim \
+ emacs \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN pip --no-cache-dir install --upgrade \
+ pip setuptools
+
+RUN pip --no-cache-dir install wheel
+
+# Download and build TensorFlow.
+WORKDIR /
+RUN rm -rf tensorflow && \
+ git clone https://github.com/tensorflow/tensorflow.git && \
+ cd tensorflow && \
+ git checkout ${TF_BRANCH}
+WORKDIR /tensorflow
+
+# Configure the build for CPU with MKL by accepting default build options and
+# setting library locations
+ENV CI_BUILD_PYTHON=python \
+ LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \
+ PYTHON_BIN_PATH=/usr/bin/python \
+ PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
+ CC_OPT_FLAGS='-march=native' \
+ TF_NEED_JEMALLOC=0 \
+ TF_NEED_GCP=1 \
+ TF_NEED_CUDA=0 \
+ TF_NEED_HDFS=0 \
+ TF_NEED_S3=1 \
+ TF_NEED_OPENCL=0 \
+ TF_NEED_GDR=0 \
+ TF_ENABLE_XLA=0 \
+ TF_NEED_VERBS=0 \
+ TF_NEED_MPI=0
+RUN ./configure
+
+# Build and Install TensorFlow.
+# The 'mkl' option builds with Intel(R) Math Kernel Library (MKL), which detects
+# the platform it is currently running on and takes appropriately optimized
+# paths. The -march=native option is for code that is not in MKL, and assumes
+# this container will be run on the same architecture on which it is built.
+RUN LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \
+ bazel build --config=mkl \
+ --config="opt" \
+ --copt="-march=broadwell" \
+ --copt="-O3" \
+ //tensorflow/tools/pip_package:build_pip_package && \
+ mkdir ${WHL_DIR} && \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package ${WHL_DIR}
+
+# Clean up the Bazel cache when done, but leave the whl.
+# This upgrades the default TensorFlow install to the Intel MKL build.
+RUN pip --no-cache-dir install --upgrade ${WHL_DIR}/tensorflow-*.whl && \
+ rm -rf /root/.cache
+
+WORKDIR /root
+
+# Add a welcome message with instructions.
+
+RUN echo '[ ! -z "$TERM" -a -r /etc/motd ] && cat /etc/issue && cat /etc/motd' \
+ >> /etc/bash.bashrc \
+ ; echo "\
+||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\
+| \n\
+| Docker container running Ubuntu \n\
+| with TensorFlow ${TF_BRANCH} optimized for CPU \n\
+| with Intel(R) MKL \n\
+| \n\
+||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\
+\n "\
+ > /etc/motd
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 204b5b4dba..2818b822b8 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -15,6 +15,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
git \
libcudnn7=7.1.4.18-1+cuda9.0 \
libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -72,7 +74,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.11.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
new file mode 100644
index 0000000000..3bedc8cf34
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -0,0 +1,115 @@
+FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
+
+LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
+
+# It is possible to override these for releases.
+ARG TF_BRANCH=master
+ARG BAZEL_VERSION=0.5.4
+ARG TF_AVAILABLE_CPUS=32
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ golang \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ python-pip \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ wget \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN pip --no-cache-dir install --upgrade \
+ pip setuptools
+
+RUN pip --no-cache-dir install \
+ ipykernel \
+ jupyter \
+ matplotlib \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ wheel \
+ && \
+ python -m ipykernel.kernelspec
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+# Set up Bazel.
+
+# Running bazel inside a `docker build` command causes trouble, cf:
+# https://github.com/bazelbuild/bazel/issues/134
+# The easiest solution is to set up a bazelrc file forcing --batch.
+RUN echo "startup --batch" >>/etc/bazel.bazelrc
+# Similarly, we need to work around sandboxing issues:
+# https://github.com/bazelbuild/bazel/issues/418
+RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
+ >>/etc/bazel.bazelrc
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ wget --quiet https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ wget --quiet https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
+ chmod +x bazel-*.sh && \
+ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+
+# Download and build TensorFlow.
+WORKDIR /
+RUN git clone https://github.com/tensorflow/tensorflow.git && \
+ cd tensorflow && \
+ git checkout ${TF_BRANCH}
+WORKDIR /tensorflow
+
+# Configure the build for our CUDA configuration.
+ENV CI_BUILD_PYTHON=python \
+ LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH} \
+ CUDNN_INSTALL_PATH=/usr/lib/x86_64-linux-gnu \
+ PYTHON_BIN_PATH=/usr/bin/python \
+ PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
+ TF_NEED_CUDA=1 \
+ TF_CUDA_VERSION=9.0 \
+ TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1,7.0 \
+ TF_CUDNN_VERSION=7
+RUN ./configure
+
+# Build and Install TensorFlow.
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
+ bazel build -c opt \
+ --config=cuda \
+ --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
+ --jobs=${TF_AVAILABLE_CPUS} \
+ tensorflow/tools/pip_package:build_pip_package && \
+ mkdir /pip_pkg && \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package /pip_pkg && \
+ pip --no-cache-dir install --upgrade /pip_pkg/tensorflow-*.whl && \
+ rm -rf /pip_pkg && \
+ rm -rf /root/.cache
+# Clean up pip wheel and Bazel cache when done.
+
+WORKDIR /root
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 6dca0e393f..c85641b383 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -73,7 +73,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.11.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index 9197651ff4..28d4371da3 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-9-0 \
curl \
libcudnn7=7.1.4.18-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
libpng12-dev \
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index bff3042990..63323c8623 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -104,6 +104,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
"//tensorflow/python/saved_model:saved_model",
"//tensorflow/python/tools:tools_pip",
+ "//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python:test_ops",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 9e41514cfa..b0089d3360 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -27,7 +27,7 @@ function cp_external() {
pushd .
cd "$src_dir"
- for f in `find . ! -type d ! -name '*.py' ! -name '*local_config_cuda*' ! -name '*local_config_tensorrt*' ! -name '*org_tensorflow*'`; do
+ for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_cuda*' ! -path '*local_config_tensorrt*' ! -path '*org_tensorflow*'`; do
mkdir -p "${dest_dir}/$(dirname ${f})"
cp "${f}" "${dest_dir}/$(dirname ${f})/"
done
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 4982cc26db..ed654c3285 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -166,11 +166,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_github_googlecloudplatform_google_cloud_cpp",
urls = [
- "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/53f822805e77ea7715f5b52c592a162c515c7219.tar.gz",
- "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/53f822805e77ea7715f5b52c592a162c515c7219.tar.gz",
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
],
- sha256 = "06853bfca77ef4aec09db5ab48c548f68ef2e18f17404cbce61f8d9b820f951b",
- strip_prefix = "google-cloud-cpp-53f822805e77ea7715f5b52c592a162c515c7219",
+ sha256 = "a34f3c50b237686dc870b13baaa6a5836ce3473f2f2a02717299f0ff318372db",
+ strip_prefix = "google-cloud-cpp-f875700a023bdd706333cde45aee8758b272c357",
)
tf_http_archive(
@@ -219,12 +219,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "nasm",
urls = [
- "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
- "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
+ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
+ "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
],
- sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324",
- strip_prefix = "nasm-2.12.02",
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
build_file = clean_dep("//third_party:nasm.BUILD"),
)
@@ -254,11 +254,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "org_sqlite",
urls = [
- "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
- "https://www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
+ "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
+ "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
],
- sha256 = "4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc",
- strip_prefix = "sqlite-amalgamation-3230100",
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
build_file = clean_dep("//third_party:sqlite.BUILD"),
)
@@ -405,11 +405,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_github_gflags_gflags",
urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz",
- "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz",
+ "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
],
- sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1",
- strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
)
tf_http_archive(
@@ -456,7 +456,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
strip_prefix = "grpc-1.13.0",
)
-
tf_http_archive(
name = "linenoise",
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
@@ -473,11 +472,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/bd8c8d759852871609ba2e4e79868420f751949d.tar.gz",
],
- sha256 = "280fdc888e2eb88a3a8cc4e7d3034fffc87f98e3e686be31f8c719c6e5b67d2d",
- strip_prefix = "llvm-d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93",
+ sha256 = "0c63e8583b213543309e8577ffe87a0cf34cc22269630d2c5c2f0a2345fda4a8",
+ strip_prefix = "llvm-bd8c8d759852871609ba2e4e79868420f751949d",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
@@ -683,12 +682,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "cython",
- sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5",
+ sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
urls = [
- "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
- "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
+ "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
+ "https://github.com/cython/cython/archive/0.28.4.tar.gz",
],
- strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17",
+ strip_prefix = "cython-0.28.4",
build_file = clean_dep("//third_party:cython.BUILD"),
delete = ["BUILD.bazel"],
)
@@ -696,11 +695,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "bazel_toolchains",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
],
- strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43",
- sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556",
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
)
tf_http_archive(
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index a014a806a6..ab57b9dfa0 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder):
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
- CLANG_REVISION = '335091'
+ CLANG_REVISION = '336424'
CLANG_SUB_REVISION = 1
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
'Linux_x64':
- '17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f',
+ '2ea97e047470da648f5d078af008bce6891287592382cee3d53a1187d996da94',
'Mac':
- '9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468',
+ 'c6e28909cce63ee35e0d51284d9f0f6e8838f7fb8b7a0dc9536c2ea900552df0',
'Win':
- 'e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e',
+ '1299fda7c4378bfb81337f7e5f351c8a1f953f51e0744e2170454b8d722f3db7',
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
diff --git a/third_party/codegen.BUILD b/third_party/codegen.BUILD
new file mode 100644
index 0000000000..df436c8163
--- /dev/null
+++ b/third_party/codegen.BUILD
@@ -0,0 +1,16 @@
+# -*- mode: python; -*-
+#
+# Description:
+# Extension to ast that allows ast -> python code generation.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # New BSD
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "com_github_andreif_codegen",
+ srcs = glob(["codegen.py"]),
+ srcs_version = "PY2AND3",
+)
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index 98cb326572..f638756d23 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -7,6 +7,7 @@ cc_toolchain_suite(
toolchains = {
"local|compiler": ":cc-compiler-local",
"darwin|compiler": ":cc-compiler-darwin",
+ "x64_windows|msvc-cl": ":cc-compiler-windows",
},
)
@@ -42,6 +43,20 @@ cc_toolchain(
supports_param_files = 0,
)
+cc_toolchain(
+ name = "cc-compiler-windows",
+ all_files = "%{win_linker_files}",
+ compiler_files = ":empty",
+ cpu = "x64_windows",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = "%{win_linker_files}",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 1,
+)
+
filegroup(
name = "empty",
srcs = [],
@@ -51,3 +66,8 @@ filegroup(
name = "crosstool_wrapper_driver_is_not_gcc",
srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
)
+
+filegroup(
+ name = "windows_msvc_wrapper_files",
+ srcs = glob(["windows/msvc_*"]),
+)
diff --git a/third_party/gpus/crosstool/CROSSTOOL.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl
index 1424ff6511..3972c96a2f 100644
--- a/third_party/gpus/crosstool/CROSSTOOL.tpl
+++ b/third_party/gpus/crosstool/CROSSTOOL.tpl
@@ -22,6 +22,10 @@ default_toolchain {
cpu: "ppc"
toolchain_identifier: "local_linux"
}
+default_toolchain {
+ cpu: "x64_windows"
+ toolchain_identifier: "local_windows"
+}
toolchain {
abi_version: "local"
@@ -537,3 +541,868 @@ toolchain {
%{host_compiler_includes}
}
+
+toolchain {
+ toolchain_identifier: "local_windows"
+ host_system_name: "local"
+ target_system_name: "local"
+
+ abi_version: "local"
+ abi_libc_version: "local"
+ target_cpu: "x64_windows"
+ compiler: "msvc-cl"
+ target_libc: "msvcrt"
+
+%{cxx_builtin_include_directory}
+
+ tool_path {
+ name: "ar"
+ path: "%{msvc_lib_path}"
+ }
+ tool_path {
+ name: "ml"
+ path: "%{msvc_ml_path}"
+ }
+ tool_path {
+ name: "cpp"
+ path: "%{msvc_cl_path}"
+ }
+ tool_path {
+ name: "gcc"
+ path: "%{msvc_cl_path}"
+ }
+ tool_path {
+ name: "gcov"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "ld"
+ path: "%{msvc_link_path}"
+ }
+ tool_path {
+ name: "nm"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objcopy"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objdump"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "strip"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ supports_interface_shared_objects: true
+
+  # TODO(pcloudy): Review the flags below; they should be defined by cl.exe.
+ compiler_flag: "/DCOMPILER_MSVC"
+
+ # Don't define min/max macros in windows.h.
+ compiler_flag: "/DNOMINMAX"
+
+ # Platform defines.
+ compiler_flag: "/D_WIN32_WINNT=0x0600"
+ # Turn off warning messages.
+ compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE"
+ compiler_flag: "/D_CRT_SECURE_NO_WARNINGS"
+ compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS"
+
+ # Useful options to have on for compilation.
+ # Increase the capacity of object files to 2^32 sections.
+ compiler_flag: "/bigobj"
+  # Allocate 500MB for precompiled headers.
+ compiler_flag: "/Zm500"
+ # Use unsigned char by default.
+ compiler_flag: "/J"
+ # Use function level linking.
+ compiler_flag: "/Gy"
+ # Use string pooling.
+ compiler_flag: "/GF"
+ # Catch C++ exceptions only and tell the compiler to assume that functions declared
+ # as extern "C" never throw a C++ exception.
+ compiler_flag: "/EHsc"
+
+ # Globally disabled warnings.
+  # Don't warn about elements of array being default initialized.
+ compiler_flag: "/wd4351"
+ # Don't warn about no matching delete found.
+ compiler_flag: "/wd4291"
+ # Don't warn about diamond inheritance patterns.
+ compiler_flag: "/wd4250"
+ # Don't warn about insecure functions (e.g. non _s functions).
+ compiler_flag: "/wd4996"
+
+ linker_flag: "/MACHINE:X64"
+
+ feature {
+ name: "no_legacy_features"
+ }
+
+ # Suppress startup banner.
+ feature {
+ name: "nologo"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ flag_group {
+ flag: "/nologo"
+ }
+ }
+ }
+
+ feature {
+ name: 'has_configured_linker_path'
+ }
+
+  # This feature indicates strip is not supported; building a stripped binary will just produce a copy of the original binary.
+ feature {
+ name: 'no_stripping'
+ }
+
+ # This feature indicates this is a toolchain targeting Windows.
+ feature {
+ name: 'targets_windows'
+ implies: 'copy_dynamic_libraries_to_binary'
+ enabled: true
+ }
+
+ feature {
+ name: 'copy_dynamic_libraries_to_binary'
+ }
+
+ action_config {
+ config_name: 'assemble'
+ action_name: 'assemble'
+ tool {
+ tool_path: '%{msvc_ml_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'preprocess-assemble'
+ action_name: 'preprocess-assemble'
+ tool {
+ tool_path: '%{msvc_ml_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'c-compile'
+ action_name: 'c-compile'
+ tool {
+ tool_path: '%{msvc_cl_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-compile'
+ action_name: 'c++-compile'
+ tool {
+ tool_path: '%{msvc_cl_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-link-executable'
+ action_name: 'c++-link-executable'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ }
+
+ action_config {
+ config_name: 'c++-link-dynamic-library'
+ action_name: 'c++-link-dynamic-library'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-nodeps-dynamic-library'
+ action_name: 'c++-link-nodeps-dynamic-library'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-static-library'
+ action_name: 'c++-link-static-library'
+ tool {
+ tool_path: '%{msvc_lib_path}'
+ }
+ implies: 'nologo'
+ implies: 'archiver_flags'
+ implies: 'input_param_flags'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ }
+
+ # TODO(b/65151735): Remove legacy_compile_flags feature when legacy fields are
+ # not used in this crosstool
+ feature {
+ name: 'legacy_compile_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'legacy_compile_flags'
+ flag: '%{legacy_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: "msvc_env"
+ env_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ env_entry {
+ key: "PATH"
+ value: "%{msvc_env_path}"
+ }
+ env_entry {
+ key: "INCLUDE"
+ value: "%{msvc_env_include}"
+ }
+ env_entry {
+ key: "LIB"
+ value: "%{msvc_env_lib}"
+ }
+ env_entry {
+ key: "TMP"
+ value: "%{msvc_env_tmp}"
+ }
+ env_entry {
+ key: "TEMP"
+ value: "%{msvc_env_tmp}"
+ }
+ }
+ }
+
+ feature {
+ name: 'include_paths'
+ flag_set {
+ action: "assemble"
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ flag_group {
+ iterate_over: 'quote_include_paths'
+ flag: '/I%{quote_include_paths}'
+ }
+ flag_group {
+ iterate_over: 'include_paths'
+ flag: '/I%{include_paths}'
+ }
+ flag_group {
+ iterate_over: 'system_include_paths'
+ flag: '/I%{system_include_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: "preprocessor_defines"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-module-compile"
+ flag_group {
+ flag: "/D%{preprocessor_defines}"
+ iterate_over: "preprocessor_defines"
+ }
+ }
+ }
+
+ # Tell Bazel to parse the output of /showIncludes
+ feature {
+ name: 'parse_showincludes'
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-module-compile'
+ action: 'c++-header-parsing'
+ flag_group {
+ flag: "/showIncludes"
+ }
+ }
+ }
+
+
+ feature {
+ name: 'generate_pdb_file'
+ requires: {
+ feature: 'dbg'
+ }
+ requires: {
+ feature: 'fastbuild'
+ }
+ }
+
+ feature {
+ name: 'shared_flag'
+ flag_set {
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/DLL'
+ }
+ }
+ }
+
+ feature {
+ name: 'linkstamps'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ expand_if_all_available: 'linkstamp_paths'
+ flag_group {
+ iterate_over: 'linkstamp_paths'
+ flag: '%{linkstamp_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: 'output_execpath_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'archiver_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'input_param_flags'
+ flag_set {
+ expand_if_all_available: 'interface_library_output_path'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/IMPLIB:%{interface_library_output_path}"
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libopts'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'libopts'
+ flag: '%{libopts}'
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libraries_to_link'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ iterate_over: 'libraries_to_link'
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file_group'
+ }
+ iterate_over: 'libraries_to_link.object_files'
+ flag_group {
+ flag: '%{libraries_to_link.object_files}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'interface_library'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'static_library'
+ }
+ flag_group {
+ expand_if_false: 'libraries_to_link.is_whole_archive'
+ flag: '%{libraries_to_link.name}'
+ }
+ flag_group {
+ expand_if_true: 'libraries_to_link.is_whole_archive'
+ flag: '/WHOLEARCHIVE:%{libraries_to_link.name}'
+ }
+ }
+ }
+ }
+ }
+
+  # Since this feature is declared earlier in the CROSSTOOL than
+  # "user_link_flags", it will be applied prior to it anywhere they are
+  # both implied. And since "user_link_flags" contains the linkopts from
+  # the build rule, this allows the user to override the /SUBSYSTEM flag
+  # in the BUILD file.
+ feature {
+ name: 'linker_subsystem_flag'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/SUBSYSTEM:CONSOLE'
+ }
+ }
+ }
+
+ # The "user_link_flags" contains user-defined linkopts (from build rules)
+ # so it should be defined after features that declare user-overridable flags.
+ # For example the "linker_subsystem_flag" defines a default "/SUBSYSTEM" flag
+ # but we want to let the user override it, therefore "link_flag_subsystem" is
+ # defined earlier in the CROSSTOOL file than "user_link_flags".
+ feature {
+ name: 'user_link_flags'
+ flag_set {
+ expand_if_all_available: 'user_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'user_link_flags'
+ flag: '%{user_link_flags}'
+ }
+ }
+ }
+ feature {
+ name: 'legacy_link_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'legacy_link_flags'
+ flag: '%{legacy_link_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'linker_param_file'
+ flag_set {
+ expand_if_all_available: 'linker_param_file'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '@%{linker_param_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'static_link_msvcrt'
+ }
+
+ feature {
+ name: 'static_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MT"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MD"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'static_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MTd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MDd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dbg'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FULL"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'fastbuild'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FASTLINK"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'opt'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/O2"
+ flag: "/DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: 'user_compile_flags'
+ flag_set {
+ expand_if_all_available: 'user_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'user_compile_flags'
+ flag: '%{user_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'sysroot'
+ flag_set {
+ expand_if_all_available: 'sysroot'
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'sysroot'
+ flag: '--sysroot=%{sysroot}'
+ }
+ }
+ }
+
+ feature {
+ name: 'unfiltered_compile_flags'
+ flag_set {
+ expand_if_all_available: 'unfiltered_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'unfiltered_compile_flags'
+ flag: '%{unfiltered_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_output_flags'
+ flag_set {
+ action: 'assemble'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ flag: '/Zi'
+ }
+ }
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_assembly_file'
+ flag: '/Fa%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_preprocess_file'
+ flag: '/P'
+ flag: '/Fi%{output_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_input_flags'
+ flag_set {
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'source_file'
+ flag: '/c'
+ flag: '%{source_file}'
+ }
+ }
+ }
+
+ feature {
+    name: 'def_file'
+ flag_set {
+ expand_if_all_available: 'def_file_path'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEF:%{def_file_path}"
+        # We can specify a different DLL name in the DEF file; /ignore:4070
+        # suppresses the warning that the DLL name doesn't match the
+        # default one.
+        # See https://msdn.microsoft.com/en-us/library/sfkk2fz7.aspx
+ flag: "/ignore:4070"
+ }
+ }
+ }
+
+ feature {
+ name: 'windows_export_all_symbols'
+ }
+
+ feature {
+ name: 'no_windows_export_all_symbols'
+ }
+
+ linking_mode_flags { mode: DYNAMIC }
+}
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index 2558f46fd5..f4f4d0ee96 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -175,6 +175,11 @@ def InvokeNvcc(argv, log=False):
# any other reliable way to just get the list of source files to be compiled.
src_files = GetOptionValue(argv, 'c')
+ # Pass -w through from host to nvcc, but don't do anything fancier with
+ # warnings-related flags, since they're not necessarily the same across
+ # compilers.
+ warning_options = ' -w' if '-w' in argv else ''
+
if len(src_files) == 0:
return 1
if len(out_file) != 1:
@@ -205,6 +210,7 @@ def InvokeNvcc(argv, log=False):
nvccopts += defines
nvccopts += std_options
nvccopts += m_options
+ nvccopts += warning_options
if depfiles:
# Generate the dependency file
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl
new file mode 100644
index 0000000000..8f8fb3e423
--- /dev/null
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl
@@ -0,0 +1,20 @@
+:: Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: http://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+:: =============================================================================
+
+:: Invoke msvc_wrapper_for_nvcc.py, which is located in the same directory.
+@echo OFF
+set arg0=%~0
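+:: In the loop below, %%~dpF expands to the drive and path of %arg0%, so
+:: DRIVER_BIN ends up as the directory containing this wrapper script.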
+for %%F in ("%arg0%") do set DRIVER_BIN=%%~dpF
+"%{python_binary}" -B "%DRIVER_BIN%\msvc_wrapper_for_nvcc.py" %*
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
new file mode 100644
index 0000000000..1a09756813
--- /dev/null
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows.
+
+DESCRIPTION:
+ This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_driver_is_not_gcc
+"""
+
+from __future__ import print_function
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('%{cpu_compiler}')
+GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
+
+NVCC_PATH = '%{nvcc_path}'
+NVCC_VERSION = '%{cuda_version}'
+NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
+supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from options.
+
+ Args:
+ option: The option whose value to extract, without the leading '/'.
+
+ Returns:
+ 1. A list of values, either directly following the option,
+ (eg., /opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., /opt val1 /opt val2).
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser(prefix_chars='/')
+ parser.add_argument('/' + option, nargs='*', action='append')
+ args, leftover = parser.parse_known_args(argv)
+ if args and vars(args)[option]:
+ return (sum(vars(args)[option], []), leftover)
+ return ([], leftover)
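+# Illustrative example (hypothetical flags): GetOptionValue(
+# ['/I', 'foo', '/I', 'bar', '/O2'], 'I') returns (['foo', 'bar'], ['/O2']):
+# the two /I values are collected and the unrecognized /O2 is left over.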
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ 1. A list of option strings that can be passed directly to nvcc.
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, leftover = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return (['--' + a for a in options], leftover)
+ return ([], leftover)
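+# Illustrative example (hypothetical options): with NVCC_VERSION other than
+# '7.0', GetNvccOptions(['-nvcc_options', 'relaxed-constexpr',
+# '-nvcc_options', 'ftz=true']) returns
+# (['--expt-relaxed-constexpr', '--ftz=true'], []), after _update_options
+# renames 'relaxed-constexpr'.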
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The exit code of the nvcc subprocess.
+ """
+
+ src_files = [f for f in argv if
+ re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ if len(src_files) == 0:
+ raise RuntimeError('No source files found for cuda compilation.')
+
+ out_file = [f for f in argv if f.startswith('/Fo')]
+ if len(out_file) != 1:
+ raise RuntimeError('Please specify exactly one output file for cuda compilation.')
+ out = ['-o', out_file[0][len('/Fo'):]]
+
+ nvcc_compiler_options, argv = GetNvccOptions(argv)
+
+ opt_option, argv = GetOptionValue(argv, 'O')
+ opt = ['-g', '-G']
+ if (len(opt_option) > 0 and opt_option[0] != 'd'):
+ opt = ['-O2']
+
+ include_options, argv = GetOptionValue(argv, 'I')
+ includes = ["-I " + include for include in include_options]
+
+ defines, argv = GetOptionValue(argv, 'D')
+ defines = ['-D' + define for define in defines]
+
+ undefines, argv = GetOptionValue(argv, 'U')
+ undefines = ['-U' + define for define in undefines]
+
+ # The rest of the unrecognized options are passed through to the host compiler.
+ host_compiler_options = [option for option in argv if option not in (src_files + out_file)]
+
+ m_options = ["-m64"]
+
+ nvccopts = ['-D_FORCE_INLINES']
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+ capability, capability, capability)]
+ nvccopts += nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += m_options
+ nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+ nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
+ # If we don't specify --keep-dir, nvcc generates intermediate files under TEMP.
+ # Put them under NVCC_TEMP_DIR instead, so that Bazel can ignore files under NVCC_TEMP_DIR during its dependency check.
+ # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver
+ # Different actions share NVCC_TEMP_DIR, so we cannot remove it if the directory already exists.
+ if os.path.isfile(NVCC_TEMP_DIR):
+ os.remove(NVCC_TEMP_DIR)
+ if not os.path.exists(NVCC_TEMP_DIR):
+ os.makedirs(NVCC_TEMP_DIR)
+ nvccopts += ['--keep', '--keep-dir', NVCC_TEMP_DIR]
+ cmd = [NVCC_PATH] + nvccopts
+ if log:
+ Log(cmd)
+ proc = subprocess.Popen(cmd,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ env=os.environ.copy(),
+ shell=True)
+ proc.wait()
+ return proc.returncode
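+# The assembled command line resembles (illustrative paths and capability):
+#   nvcc -D_FORCE_INLINES -gencode=arch=compute_35,"code=sm_35,compute_35"
+#   -m64 --compiler-options="/MD" -x cu -O2 -I third_party -o foo.obj -c foo.cc
+#   --keep --keep-dir C:\Windows\Temp\nvcc_inter_files_tmp_dir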
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith('--cuda_log')
+ and not flag.startswith('-nvcc_options')]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl
new file mode 100644
index 0000000000..ff6b3cc351
--- /dev/null
+++ b/third_party/gpus/cuda/BUILD.windows.tpl
@@ -0,0 +1,163 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_nvcc",
+ values = {
+ "define": "using_cuda_nvcc=true",
+ },
+)
+
+config_setting(
+ name = "using_clang",
+ values = {
+ "define": "using_cuda_clang=true",
+ },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting(
+ name = "using_clang_opt",
+ values = {
+ "define": "using_cuda_clang=true",
+ "compilation_mode": "opt",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {"cpu": "darwin"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "freebsd",
+ values = {"cpu": "freebsd"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ %{cuda_headers}
+ ],
+ includes = [
+ ".",
+ "cuda/include",
+ "cuda/include/crt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudart_static",
+ # /WHOLEARCHIVE:cudart_static.lib will cause an
+ # "Internal error during CImplib::EmitThunk" error.
+ # Treat this library as an interface library to avoid it being whole-archived
+ # when linking a DLL that depends on it.
+ # TODO(pcloudy): Remove this rule after b/111278841 is resolved.
+ interface_library = "cuda/lib/%{cudart_static_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cuda_driver",
+ interface_library = "cuda/lib/%{cuda_driver_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudart",
+ interface_library = "cuda/lib/%{cudart_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cublas",
+ interface_library = "cuda/lib/%{cublas_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cusolver",
+ interface_library = "cuda/lib/%{cusolver_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudnn",
+ interface_library = "cuda/lib/%{cudnn_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cufft",
+ interface_library = "cuda/lib/%{cufft_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "curand",
+ interface_library = "cuda/lib/%{curand_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cublas",
+ ":cuda_headers",
+ ":cudart",
+ ":cudnn",
+ ":cufft",
+ ":curand",
+ ],
+)
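+# A dependent target would typically pull all of these in at once, e.g.
+# (illustrative target name):
+#
+#   cc_library(
+#       name = "my_gpu_op",
+#       deps = ["@local_config_cuda//cuda:cuda"],
+#   )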
+
+cc_library(
+ name = "cupti_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-extras",
+ ],
+ includes = [
+ ".",
+ "cuda/extras/CUPTI/include/",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cupti_dsos",
+ interface_library = "cuda/lib/%{cupti_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "libdevice_root",
+ data = [":cuda-nvvm"],
+ visibility = ["//visibility:public"],
+)
+
+%{cuda_include_genrules}
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index c90c66912d..e848fa175c 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -20,6 +20,7 @@
`/usr/local/cuda`.
* `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
`3.5,5.2`.
+ * `PYTHON_BIN_PATH`: The python binary path.
"""
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
@@ -31,6 +32,7 @@ _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
+_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
_DEFAULT_CUDA_VERSION = ""
_DEFAULT_CUDNN_VERSION = ""
@@ -44,12 +46,12 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
# will be used. For example, when looking for the cudart libraries, the first
# attempt will be lib64/cudart inside the CUDA toolkit.
CUDA_LIB_PATHS = [
- "lib64/",
- "lib64/stubs/",
- "lib/x86_64-linux-gnu/",
- "lib/x64/",
- "lib/",
- "",
+ "lib64/",
+ "lib64/stubs/",
+ "lib/x86_64-linux-gnu/",
+ "lib/x64/",
+ "lib/",
+ "",
]
# Lookup paths for cupti.h, relative to the CUDA toolkit directory.
@@ -57,8 +59,8 @@ CUDA_LIB_PATHS = [
# On most systems, the cupti library is not installed in the same directory as
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_HEADER_PATHS = [
- "extras/CUPTI/include/",
- "include/cuda/CUPTI/",
+ "extras/CUPTI/include/",
+ "include/cuda/CUPTI/",
]
# Lookup paths for the cupti library, relative to the
@@ -66,25 +68,25 @@ CUPTI_HEADER_PATHS = [
# On most systems, the cupti library is not installed in the same directory as
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_LIB_PATHS = [
- "extras/CUPTI/lib64/",
- "lib/x86_64-linux-gnu",
- "lib64/",
- "extras/CUPTI/libx64/",
- "extras/CUPTI/lib/",
- "lib/",
+ "extras/CUPTI/lib64/",
+ "lib/x86_64-linux-gnu",
+ "lib64/",
+ "extras/CUPTI/libx64/",
+ "extras/CUPTI/lib/",
+ "lib/",
]
# Lookup paths for CUDA headers (cuda.h) relative to the CUDA toolkit directory.
CUDA_INCLUDE_PATHS = [
- "include/",
- "include/cuda/"
+ "include/",
+ "include/cuda/",
]
# Lookup paths for cudnn.h relative to the CUDNN install directory.
CUDNN_INCLUDE_PATHS = [
- "",
- "include/",
- "include/cuda/",
+ "",
+ "include/",
+ "include/cuda/",
]
# Lookup paths for NVVM libdevice relative to the CUDA directory toolkit.
@@ -92,686 +94,841 @@ CUDNN_INCLUDE_PATHS = [
# libdevice implements mathematical functions for GPU kernels, and is provided
# in NVVM bitcode (a subset of LLVM bitcode).
NVVM_LIBDEVICE_PATHS = [
- "nvvm/libdevice/",
- "share/cuda/",
+ "nvvm/libdevice/",
+ "share/cuda/",
+]
+
+# Files used to detect the NVVM libdevice path.
+NVVM_LIBDEVICE_FILES = [
+ # CUDA 9.0 has a single file.
+ "libdevice.10.bc",
+
+ # CUDA 8.0 has separate files for compute versions 2.0, 3.0, 3.5 and 5.0.
+ # Probing for one of them is sufficient.
+ "libdevice.compute_20.10.bc",
]
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
+load(
+ "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
+ "escape_string",
+ "get_env_var",
+)
+load(
+ "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
+ "find_msvc_tool",
+ "find_vc_path",
+ "setup_vc_env_vars",
+)
+
+def _get_python_bin(repository_ctx):
+ """Gets the python bin path."""
+ python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
+ if python_bin != None:
+ return python_bin
+ python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
+ python_bin_path = repository_ctx.which(python_bin_name)
+ if python_bin_path != None:
+ return str(python_bin_path)
+ auto_configure_fail("Cannot find python in PATH, please make sure " +
+ "python is installed and add its directory in PATH, or --define " +
+ "%s='/something/else'.\nPATH=%s" % (
+ _PYTHON_BIN_PATH,
+ repository_ctx.os.environ.get("PATH", ""),
+ ))
+
+def _get_nvcc_tmp_dir_for_windows(repository_ctx):
+ """Return the tmp directory for nvcc to generate intermediate source files."""
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
+ )
+ return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
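+# For example, with TMP=C:\Users\me\AppData\Local\Temp this returns
+# "C:\\Users\\me\\AppData\\Local\\Temp\\nvcc_inter_files_tmp_dir", with
+# backslashes doubled so the value can be embedded in the CROSSTOOL template.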
+
+def _get_msvc_compiler(repository_ctx):
+ vc_path = find_vc_path(repository_ctx)
+ return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
+
+def _get_win_cuda_defines(repository_ctx):
+ """Return CROSSTOOL defines for Windows"""
+
+ # If we are not on Windows, return empty vaules for Windows specific fields.
+ # This ensures the CROSSTOOL file parser is happy.
+ if not _is_windows(repository_ctx):
+ return {
+ "%{msvc_env_tmp}": "",
+ "%{msvc_env_path}": "",
+ "%{msvc_env_include}": "",
+ "%{msvc_env_lib}": "",
+ "%{msvc_cl_path}": "",
+ "%{msvc_ml_path}": "",
+ "%{msvc_link_path}": "",
+ "%{msvc_lib_path}": "",
+ "%{cxx_builtin_include_directory}": "",
+ }
+
+ vc_path = find_vc_path(repository_ctx)
+ if not vc_path:
+ auto_configure_fail("Visual C++ build tools not found on your machine." +
+ "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using")
+ return {}
+
+ env = setup_vc_env_vars(repository_ctx, vc_path)
+ escaped_paths = escape_string(env["PATH"])
+ escaped_include_paths = escape_string(env["INCLUDE"])
+ escaped_lib_paths = escape_string(env["LIB"])
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
+ )
+
+ msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat"
+ msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace("\\", "/")
+ msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace("\\", "/")
+ msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace("\\", "/")
+
+ # nvcc will generate some temporary source files under %{nvcc_tmp_dir}.
+ # The generated files are guaranteed to have unique names, so they can share the same tmp directory.
+ escaped_cxx_include_directories = ["cxx_builtin_include_directory: \"%s\"" % _get_nvcc_tmp_dir_for_windows(repository_ctx)]
+ for path in escaped_include_paths.split(";"):
+ if path:
+ escaped_cxx_include_directories.append("cxx_builtin_include_directory: \"%s\"" % path)
+
+ return {
+ "%{msvc_env_tmp}": escaped_tmp_dir,
+ "%{msvc_env_path}": escaped_paths,
+ "%{msvc_env_include}": escaped_include_paths,
+ "%{msvc_env_lib}": escaped_lib_paths,
+ "%{msvc_cl_path}": msvc_cl_path,
+ "%{msvc_ml_path}": msvc_ml_path,
+ "%{msvc_link_path}": msvc_link_path,
+ "%{msvc_lib_path}": msvc_lib_path,
+ "%{cxx_builtin_include_directory}": "\n".join(escaped_cxx_include_directories),
+ }
# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
- """Find the C++ compiler."""
- # On Windows, we use Bazel's MSVC CROSSTOOL for GPU build
- # Return a dummy value for GCC detection here to avoid error
- if _is_windows(repository_ctx):
- return "/use/--config=win-cuda --cpu=x64_windows_msvc/instead"
-
- if _use_cuda_clang(repository_ctx):
- target_cc_name = "clang"
- cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
- if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
- return "extra_tools/bin/clang"
- else:
- target_cc_name = "gcc"
- cc_path_envvar = _GCC_HOST_COMPILER_PATH
- cc_name = target_cc_name
-
- if cc_path_envvar in repository_ctx.os.environ:
- cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
- if cc_name_from_env:
- cc_name = cc_name_from_env
- if cc_name.startswith("/"):
- # Absolute path, maybe we should make this supported by our which function.
- return cc_name
- cc = repository_ctx.which(cc_name)
- if cc == None:
- fail(("Cannot find {}, either correct your path or set the {}" +
- " environment variable").format(target_cc_name, cc_path_envvar))
- return cc
-
+ """Find the C++ compiler."""
+ if _is_windows(repository_ctx):
+ return _get_msvc_compiler(repository_ctx)
+
+ if _use_cuda_clang(repository_ctx):
+ target_cc_name = "clang"
+ cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
+ if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
+ return "extra_tools/bin/clang"
+ else:
+ target_cc_name = "gcc"
+ cc_path_envvar = _GCC_HOST_COMPILER_PATH
+ cc_name = target_cc_name
+
+ if cc_path_envvar in repository_ctx.os.environ:
+ cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
+ if cc_name_from_env:
+ cc_name = cc_name_from_env
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this supported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(("Cannot find {}, either correct your path or set the {}" +
+ " environment variable").format(target_cc_name, cc_path_envvar))
+ return cc
_INC_DIR_MARKER_BEGIN = "#include <...>"
-
# OSX add " (framework directory)" at the end of line, strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
-_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
-def _cxx_inc_convert(path):
- """Convert path returned by cc -E xc++ in a complete path."""
- path = path.strip()
- if path.endswith(_OSX_FRAMEWORK_SUFFIX):
- path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
- return path
+_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
+def _cxx_inc_convert(path):
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ if path.endswith(_OSX_FRAMEWORK_SUFFIX):
+ path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
+ return path
def _normalize_include_path(repository_ctx, path):
- """Normalizes include paths before writing them to the crosstool.
+ """Normalizes include paths before writing them to the crosstool.
- If path points inside the 'crosstool' folder of the repository, a relative
- path is returned.
- If path points outside the 'crosstool' folder, an absolute path is returned.
- """
- path = str(repository_ctx.path(path))
- crosstool_folder = str(repository_ctx.path(".").get_child('crosstool'))
-
- if path.startswith(crosstool_folder):
- # We drop the path to "$REPO/crosstool" and a trailing path separator.
- return path[len(crosstool_folder)+1:]
- return path
+ If path points inside the 'crosstool' folder of the repository, a relative
+ path is returned.
+ If path points outside the 'crosstool' folder, an absolute path is returned.
+ """
+ path = str(repository_ctx.path(path))
+ crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
+ if path.startswith(crosstool_folder):
+ # We drop the path to "$REPO/crosstool" and a trailing path separator.
+ return path[len(crosstool_folder) + 1:]
+ return path
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
- """Compute the list of default C or C++ include directories."""
- if lang_is_cpp:
- lang = "c++"
- else:
- lang = "c"
- result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
- index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
- if index1 == -1:
- return []
- index1 = result.stderr.find("\n", index1)
- if index1 == -1:
- return []
- index2 = result.stderr.rfind("\n ")
- if index2 == -1 or index2 < index1:
- return []
- index2 = result.stderr.find("\n", index2 + 1)
- if index2 == -1:
- inc_dirs = result.stderr[index1 + 1:]
- else:
- inc_dirs = result.stderr[index1 + 1:index2].strip()
-
- return [
- _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
- for p in inc_dirs.split("\n")
- ]
+ """Compute the list of default C or C++ include directories."""
+ if lang_is_cpp:
+ lang = "c++"
+ else:
+ lang = "c"
+ result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+ return [
+ _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
+ for p in inc_dirs.split("\n")
+ ]
def get_cxx_inc_directories(repository_ctx, cc):
- """Compute the list of default C and C++ include directories."""
- # For some reason `clang -xc` sometimes returns include paths that are
- # different from the ones from `clang -xc++`. (Symlink and a dir)
- # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
- includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
- includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+ """Compute the list of default C and C++ include directories."""
- includes_cpp_set = depset(includes_cpp)
- return includes_cpp + [inc for inc in includes_c
- if inc not in includes_cpp_set]
+ # For some reason `clang -xc` sometimes returns include paths that are
+ # different from the ones from `clang -xc++`. (Symlink and a dir)
+ # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
+ includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
+ includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+ includes_cpp_set = depset(includes_cpp)
+ return includes_cpp + [
+ inc
+ for inc in includes_c
+ if inc not in includes_cpp_set
+ ]
def auto_configure_fail(msg):
- """Output failure message when cuda configuration fails."""
- red = "\033[0;31m"
- no_color = "\033[0m"
- fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
-# END cc_configure common functions (see TODO above).
+ """Output failure message when cuda configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
+# END cc_configure common functions (see TODO above).
def _host_compiler_includes(repository_ctx, cc):
- """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
-
- Args:
- repository_ctx: The repository context.
- cc: The path to the gcc host compiler.
-
- Returns:
- A string containing the cxx_builtin_include_directory for each of the gcc
- host compiler include directories, which can be added to the CROSSTOOL
- file.
- """
- inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
- inc_entries = []
- for inc_dir in inc_dirs:
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
- return "\n".join(inc_entries)
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
def _cuda_include_path(repository_ctx, cuda_config):
- """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
-
- Args:
- repository_ctx: The repository context.
- cc: The path to the gcc host compiler.
-
- Returns:
- A string containing the cxx_builtin_include_directory for each of the gcc
- host compiler include directories, which can be added to the CROSSTOOL
- file.
- """
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_config.cuda_toolkit_path,
- ".exe" if cuda_config.cpu_value == "Windows" else ""))
- result = repository_ctx.execute([nvcc_path, '-v',
- '/dev/null', '-o', '/dev/null'])
- target_dir = ""
- for one_line in result.stderr.splitlines():
- if one_line.startswith('#$ _TARGET_DIR_='):
- target_dir = (cuda_config.cuda_toolkit_path + '/' +
- one_line.replace('#$ _TARGET_DIR_=', '') + "/include")
- inc_entries = []
- if target_dir != "":
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
- default_include = cuda_config.cuda_toolkit_path + '/include'
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" %
- default_include)
- return "\n".join(inc_entries)
+ """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config.
-def _enable_cuda(repository_ctx):
- if "TF_NEED_CUDA" in repository_ctx.os.environ:
- enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
- return enable_cuda == "1"
- return False
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the CUDA
+ include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if cuda_config.cpu_value == "Windows" else "",
+ ))
+ result = repository_ctx.execute([
+ nvcc_path,
+ "-v",
+ "/dev/null",
+ "-o",
+ "/dev/null",
+ ])
+ target_dir = ""
+ for one_line in result.stderr.splitlines():
+ if one_line.startswith("#$ _TARGET_DIR_="):
+ target_dir = (cuda_config.cuda_toolkit_path + "/" +
+ one_line.replace("#$ _TARGET_DIR_=", "") + "/include")
+ inc_entries = []
+ if target_dir != "":
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
+ default_include = cuda_config.cuda_toolkit_path + "/include"
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" %
+ default_include)
+ return "\n".join(inc_entries)
+def _enable_cuda(repository_ctx):
+ if "TF_NEED_CUDA" in repository_ctx.os.environ:
+ enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
+ return enable_cuda == "1"
+ return False
def _cuda_toolkit_path(repository_ctx):
- """Finds the cuda toolkit directory.
-
- Args:
- repository_ctx: The repository context.
+ """Finds the cuda toolkit directory.
- Returns:
- A speculative real path of the cuda toolkit install directory.
- """
- cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
- if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
- cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
- if not repository_ctx.path(cuda_toolkit_path).exists:
- auto_configure_fail("Cannot find cuda toolkit path.")
- return str(repository_ctx.path(cuda_toolkit_path).realpath)
+ Args:
+ repository_ctx: The repository context.
+ Returns:
+ A speculative real path of the cuda toolkit install directory.
+ """
+ cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
+ if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
+ cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
+ if not repository_ctx.path(cuda_toolkit_path).exists:
+ auto_configure_fail("Cannot find cuda toolkit path.")
+ return str(repository_ctx.path(cuda_toolkit_path).realpath)
def _cudnn_install_basedir(repository_ctx):
- """Finds the cudnn install directory."""
- cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
- if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
- cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
- if not repository_ctx.path(cudnn_install_path).exists:
- auto_configure_fail("Cannot find cudnn install path.")
- return cudnn_install_path
-
+ """Finds the cudnn install directory."""
+ cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
+ if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
+ cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
+ if not repository_ctx.path(cudnn_install_path).exists:
+ auto_configure_fail("Cannot find cudnn install path.")
+ return cudnn_install_path
def matches_version(environ_version, detected_version):
- """Checks whether the user-specified version matches the detected version.
-
- This function performs a weak matching so that if the user specifies only the
- major or major and minor versions, the versions are still considered matching
- if the version parts match. To illustrate:
-
- environ_version detected_version result
- -----------------------------------------
- 5.1.3 5.1.3 True
- 5.1 5.1.3 True
- 5 5.1 True
- 5.1.3 5.1 False
- 5.2.3 5.1.3 False
-
- Args:
- environ_version: The version specified by the user via environment
- variables.
- detected_version: The version autodetected from the CUDA installation on
- the system.
-
- Returns: True if user-specified version matches detected version and False
- otherwise.
- """
- environ_version_parts = environ_version.split(".")
- detected_version_parts = detected_version.split(".")
- if len(detected_version_parts) < len(environ_version_parts):
- return False
- for i, part in enumerate(detected_version_parts):
- if i >= len(environ_version_parts):
- break
- if part != environ_version_parts[i]:
- return False
- return True
-
+ """Checks whether the user-specified version matches the detected version.
+
+ This function performs a weak matching so that if the user specifies only the
+ major or major and minor versions, the versions are still considered matching
+ if the version parts match. To illustrate:
+
+ environ_version detected_version result
+ -----------------------------------------
+ 5.1.3 5.1.3 True
+ 5.1 5.1.3 True
+ 5 5.1 True
+ 5.1.3 5.1 False
+ 5.2.3 5.1.3 False
+
+ Args:
+ environ_version: The version specified by the user via environment
+ variables.
+ detected_version: The version autodetected from the CUDA installation on
+ the system.
+
+ Returns: True if user-specified version matches detected version and False
+ otherwise.
+ """
+ environ_version_parts = environ_version.split(".")
+ detected_version_parts = detected_version.split(".")
+ if len(detected_version_parts) < len(environ_version_parts):
+ return False
+ for i, part in enumerate(detected_version_parts):
+ if i >= len(environ_version_parts):
+ break
+ if part != environ_version_parts[i]:
+ return False
+ return True
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
-
def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
- """Detects the version of CUDA installed on the system.
-
- Args:
- repository_ctx: The repository context.
- cuda_toolkit_path: The CUDA install directory.
-
- Returns:
- String containing the version of CUDA.
- """
- # Run nvcc --version and find the line containing the CUDA version.
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_toolkit_path,
- ".exe" if cpu_value == "Windows" else ""))
- if not nvcc_path.exists:
- auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
- result = repository_ctx.execute([str(nvcc_path), '--version'])
- if result.stderr:
- auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
- lines = result.stdout.splitlines()
- version_line = lines[len(lines) - 1]
- if version_line.find(_NVCC_VERSION_PREFIX) == -1:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout)
-
- # Parse the CUDA version from the line containing the CUDA version.
- prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, '')
- parts = prefix_removed.split(",")
- if len(parts) != 2 or len(parts[0]) < 2:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout)
- full_version = parts[1].strip()
- if full_version.startswith('V'):
- full_version = full_version[1:]
-
- # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDA_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- auto_configure_fail(
- ("CUDA version detected from nvcc (%s) does not match " +
- "TF_CUDA_VERSION (%s)") % (full_version, environ_version))
-
- # We only use the version consisting of the major and minor version numbers.
- version_parts = full_version.split('.')
- if len(version_parts) < 2:
- auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
- if cpu_value == "Windows":
- version = "64_%s%s" % (version_parts[0], version_parts[1])
- else:
- version = "%s.%s" % (version_parts[0], version_parts[1])
- return version
+ """Detects the version of CUDA installed on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ cuda_toolkit_path: The CUDA install directory.
+ cpu_value: The name of the host operating system.
+
+ Returns:
+ String containing the version of CUDA.
+ """
+
+ # Run nvcc --version and find the line containing the CUDA version.
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_toolkit_path,
+ ".exe" if cpu_value == "Windows" else "",
+ ))
+ if not nvcc_path.exists:
+ auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
+ result = repository_ctx.execute([str(nvcc_path), "--version"])
+ if result.stderr:
+ auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
+ lines = result.stdout.splitlines()
+ version_line = lines[len(lines) - 1]
+ if version_line.find(_NVCC_VERSION_PREFIX) == -1:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,
+ )
+ # Parse the CUDA version from the line containing the CUDA version.
+ prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "")
+ parts = prefix_removed.split(",")
+ if len(parts) != 2 or len(parts[0]) < 2:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,
+ )
+ full_version = parts[1].strip()
+ if full_version.startswith("V"):
+ full_version = full_version[1:]
+
+ # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDA_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ auto_configure_fail(
+ ("CUDA version detected from nvcc (%s) does not match " +
+ "TF_CUDA_VERSION (%s)") % (full_version, environ_version),
+ )
+
+ # We only use the version consisting of the major and minor version numbers.
+ version_parts = full_version.split(".")
+ if len(version_parts) < 2:
+ auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
+ if cpu_value == "Windows":
+ version = "64_%s%s" % (version_parts[0], version_parts[1])
+ else:
+ version = "%s.%s" % (version_parts[0], version_parts[1])
+ return version
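+# For example (illustrative output): if the last line of `nvcc --version` is
+# "Cuda compilation tools, release 9.0, V9.0.176", this returns "9.0" on
+# Linux and "64_90" on Windows.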
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
_DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-
def find_cuda_define(repository_ctx, header_dir, header_file, define):
- """Returns the value of a #define in a header file.
-
- Greps through a header file and returns the value of the specified #define.
- If the #define is not found, then raise an error.
-
- Args:
- repository_ctx: The repository context.
- header_dir: The directory containing the header file.
- header_file: The header file name.
- define: The #define to search for.
-
- Returns:
- The value of the #define found in the header.
- """
- # Confirm location of the header and grep for the line defining the macro.
- h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
- if not h_path.exists:
- auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
- result = repository_ctx.execute(
- # Grep one more lines as some #defines are splitted into two lines.
- ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
- if result.stderr:
- auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
-
- # Parse the version from the line defining the macro.
- if result.stdout.find(define) == -1:
- auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, h_path))
- # Split results to lines
- lines = result.stdout.split('\n')
- num_lines = len(lines)
- for l in range(num_lines):
- line = lines[l]
- if define in line: # Find the line with define
- version = line
- if l != num_lines-1 and line[-1] == '\\': # Add next line, if multiline
- version = version[:-1] + lines[l+1]
- break
- # Remove any comments
- version = version.split("//")[0]
- # Remove define name
- version = version.replace(define, "").strip()
- # Remove the code after the version number.
- version_end = version.find(" ")
- if version_end != -1:
- if version_end == 0:
- auto_configure_fail(
- "Cannot extract the version from line containing '%s' in %s" %
- (define, str(h_path)))
- version = version[:version_end].strip()
- return version
+ """Returns the value of a #define in a header file.
+
+ Greps through a header file and returns the value of the specified #define.
+ If the #define is not found, then raise an error.
+ Args:
+ repository_ctx: The repository context.
+ header_dir: The directory containing the header file.
+ header_file: The header file name.
+ define: The #define to search for.
+
+ Returns:
+ The value of the #define found in the header.
+ """
+
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more line, as some #defines are split across two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)],
+ )
+ if result.stderr:
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
+
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
+ auto_configure_fail("Cannot find line containing '%s' in %s" %
+ (define, h_path))
+
+ # Split results to lines
+ lines = result.stdout.split("\n")
+ num_lines = len(lines)
+ for l in range(num_lines):
+ line = lines[l]
+ if define in line: # Find the line with define
+ version = line
+ if l != num_lines - 1 and line[-1] == "\\": # Add next line, if multiline
+ version = version[:-1] + lines[l + 1]
+ break
+
+ # Remove any comments
+ version = version.split("//")[0]
+
+ # Remove define name
+ version = version.replace(define, "").strip()
+
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)),
+ )
+ version = version[:version_end].strip()
+ return version
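+# Illustrative example: for a header line "#define CUDNN_MAJOR 7",
+# find_cuda_define(repository_ctx, header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
+# returns "7".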
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
- """Detects the version of cuDNN installed on the system.
-
- Args:
- repository_ctx: The repository context.
- cpu_value: The name of the host operating system.
- cudnn_install_basedir: The cuDNN install directory.
-
- Returns:
- A string containing the version of cuDNN.
- """
- cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
- cudnn_install_basedir)
- major_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
- minor_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
- patch_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
- full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
-
- # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDNN_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
- cudnn_install_basedir)
- auto_configure_fail(
- ("cuDNN version detected from %s (%s) does not match " +
- "TF_CUDNN_VERSION (%s)") %
- (str(cudnn_h_path), full_version, environ_version))
-
- # We only use the major version since we use the libcudnn libraries that are
- # only versioned with the major version (e.g. libcudnn.so.5).
- version = major_version
- if cpu_value == "Windows":
- version = "64_" + version
- return version
+ """Detects the version of cuDNN installed on the system.
+ Args:
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ cudnn_install_basedir: The cuDNN install directory.
-def _compute_capabilities(repository_ctx):
- """Returns a list of strings representing cuda compute capabilities."""
- if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
- return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
- capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
- capabilities = capabilities_str.split(",")
- for capability in capabilities:
- # Workaround for Skylark's lack of support for regex. This check should
- # be equivalent to checking:
- # if re.match("[0-9]+.[0-9]+", capability) == None:
- parts = capability.split(".")
- if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
- auto_configure_fail("Invalid compute capability: %s" % capability)
- return capabilities
+ Returns:
+ A string containing the version of cuDNN.
+ """
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cudnn_install_basedir,
+ )
+ major_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MAJOR,
+ )
+ minor_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MINOR,
+ )
+ patch_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_PATCHLEVEL,
+ )
+ full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+
+ # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDNN_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
+ cudnn_install_basedir)
+ auto_configure_fail(
+ ("cuDNN version detected from %s (%s) does not match " +
+ "TF_CUDNN_VERSION (%s)") %
+ (str(cudnn_h_path), full_version, environ_version),
+ )
+ # We only use the major version since we use the libcudnn libraries that are
+ # only versioned with the major version (e.g. libcudnn.so.5).
+ version = major_version
+ if cpu_value == "Windows":
+ version = "64_" + version
+ return version
-def get_cpu_value(repository_ctx):
- """Returns the name of the host operating system.
+def _compute_capabilities(repository_ctx):
+ """Returns a list of strings representing cuda compute capabilities."""
+ if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
+ return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
+ capabilities = capabilities_str.split(",")
+ for capability in capabilities:
+ # Workaround for Skylark's lack of support for regex. This check should
+ # be equivalent to checking:
+ # if re.match("[0-9]+.[0-9]+", capability) == None:
+ parts = capability.split(".")
+ if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
+ auto_configure_fail("Invalid compute capability: %s" % capability)
+ return capabilities
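+# For example: TF_CUDA_COMPUTE_CAPABILITIES="3.5,7.0" yields ["3.5", "7.0"],
+# while a malformed entry such as "sm_35" triggers auto_configure_fail.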
- Args:
- repository_ctx: The repository context.
+def get_cpu_value(repository_ctx):
+ """Returns the name of the host operating system.
- Returns:
- A string containing the name of the host operating system.
- """
- os_name = repository_ctx.os.name.lower()
- if os_name.startswith("mac os"):
- return "Darwin"
- if os_name.find("windows") != -1:
- return "Windows"
- result = repository_ctx.execute(["uname", "-s"])
- return result.stdout.strip()
+ Args:
+ repository_ctx: The repository context.
+ Returns:
+ A string containing the name of the host operating system.
+ """
+ os_name = repository_ctx.os.name.lower()
+ if os_name.startswith("mac os"):
+ return "Darwin"
+ if os_name.find("windows") != -1:
+ return "Windows"
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
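+# For example: returns "Darwin" on macOS, "Windows" on Windows, and otherwise
+# the output of `uname -s` (e.g. "Linux").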
def _is_windows(repository_ctx):
- """Returns true if the host operating system is windows."""
- return get_cpu_value(repository_ctx) == "Windows"
-
-def _lib_name(lib, cpu_value, version="", static=False):
- """Constructs the platform-specific name of a library.
-
- Args:
- lib: The name of the library, such as "cudart"
- cpu_value: The name of the host operating system.
- version: The version of the library.
- static: True the library is static or False if it is a shared object.
-
- Returns:
- The platform-specific name of the library.
- """
- if cpu_value in ("Linux", "FreeBSD"):
- if static:
- return "lib%s.a" % lib
- else:
- if version:
- version = ".%s" % version
- return "lib%s.so%s" % (lib, version)
- elif cpu_value == "Windows":
- return "%s.lib" % lib
- elif cpu_value == "Darwin":
- if static:
- return "lib%s.a" % lib
- else:
- if version:
- version = ".%s" % version
- return "lib%s%s.dylib" % (lib, version)
- else:
- auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
-
-
-def _find_cuda_lib(lib, repository_ctx, cpu_value, basedir, version="",
- static=False):
- """Finds the given CUDA or cuDNN library on the system.
-
- Args:
- lib: The name of the library, such as "cudart"
- repository_ctx: The repository context.
- cpu_value: The name of the host operating system.
- basedir: The install directory of CUDA or cuDNN.
- version: The version of the library.
- static: True if static library, False if shared object.
-
- Returns:
- Returns a struct with the following fields:
- file_name: The basename of the library found on the system.
- path: The full path to the library.
- """
- file_name = _lib_name(lib, cpu_value, version, static)
- for relative_path in CUDA_LIB_PATHS:
- path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
- if path.exists:
- return struct(file_name=file_name, path=str(path.realpath))
- auto_configure_fail("Cannot find cuda library %s" % file_name)
+ """Returns true if the host operating system is windows."""
+ return get_cpu_value(repository_ctx) == "Windows"
+def _lib_name(lib, cpu_value, version = "", static = False):
+ """Constructs the platform-specific name of a library.
-def _find_cupti_header_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing cupti.h
+ Args:
+ lib: The name of the library, such as "cudart"
+ cpu_value: The name of the host operating system.
+ version: The version of the library.
+ static: True if the library is static, or False if it is a shared object.
+
+ Returns:
+ The platform-specific name of the library.
+ """
+ if cpu_value in ("Linux", "FreeBSD"):
+ if static:
+ return "lib%s.a" % lib
+ else:
+ if version:
+ version = ".%s" % version
+ return "lib%s.so%s" % (lib, version)
+ elif cpu_value == "Windows":
+ return "%s.lib" % lib
+ elif cpu_value == "Darwin":
+ if static:
+ return "lib%s.a" % lib
+ elif version:
+ version = ".%s" % version
+ return "lib%s%s.dylib" % (lib, version)
+ else:
+ auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+
+def _find_cuda_lib(
+ lib,
+ repository_ctx,
+ cpu_value,
+ basedir,
+ version = "",
+ static = False):
+ """Finds the given CUDA or cuDNN library on the system.
+
+ Args:
+ lib: The name of the library, such as "cudart"
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ basedir: The install directory of CUDA or cuDNN.
+ version: The version of the library.
+ static: True if static library, False if shared object.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(lib, cpu_value, version, static)
+ for relative_path in CUDA_LIB_PATHS:
+ path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ auto_configure_fail("Cannot find cuda library %s" % file_name)
- On most systems, the cupti library is not installed in the same directory as
- the other CUDA libraries but rather in a special extras/CUPTI directory.
+def _find_cupti_header_dir(repository_ctx, cuda_config):
+ """Returns the path to the directory containing cupti.h
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
+ On most systems, the cupti library is not installed in the same directory as
+ the other CUDA libraries but rather in a special extras/CUPTI directory.
- Returns:
- The path of the directory containing the cupti header.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_HEADER_PATHS:
- if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the cupti header.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_HEADER_PATHS:
+ if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
def _find_cupti_lib(repository_ctx, cuda_config):
- """Finds the cupti library on the system.
-
- On most systems, the cupti library is not installed in the same directory as
- the other CUDA libraries but rather in a special extras/CUPTI directory.
-
- Args:
- repository_ctx: The repository context.
- cuda_config: The cuda configuration as returned by _get_cuda_config.
-
- Returns:
- Returns a struct with the following fields:
- file_name: The basename of the library found on the system.
- path: The full path to the library.
- """
- file_name = _lib_name("cupti", cuda_config.cpu_value,
- cuda_config.cuda_version)
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_LIB_PATHS:
- path = repository_ctx.path(
- "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name))
- if path.exists:
- return struct(file_name=file_name, path=str(path.realpath))
-
- auto_configure_fail("Cannot find cupti library %s" % file_name)
+ """Finds the cupti library on the system.
+
+ On most systems, the cupti library is not installed in the same directory as
+ the other CUDA libraries but rather in a special extras/CUPTI directory.
+
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The cuda configuration as returned by _get_cuda_config.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(
+ "cupti",
+ cuda_config.cpu_value,
+ cuda_config.cuda_version,
+ )
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_LIB_PATHS:
+ path = repository_ctx.path(
+ "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name),
+ )
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ auto_configure_fail("Cannot find cupti library %s" % file_name)
def _find_libs(repository_ctx, cuda_config):
- """Returns the CUDA and cuDNN libraries on the system.
-
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
-
- Returns:
- Map of library names to structs of filename and path.
- """
- cpu_value = cuda_config.cpu_value
- return {
- "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
- "cudart": _find_cuda_lib(
- "cudart", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cudart_static": _find_cuda_lib(
- "cudart_static", repository_ctx, cpu_value,
- cuda_config.cuda_toolkit_path, cuda_config.cuda_version, static=True),
- "cublas": _find_cuda_lib(
- "cublas", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cusolver": _find_cuda_lib(
- "cusolver", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "curand": _find_cuda_lib(
- "curand", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cufft": _find_cuda_lib(
- "cufft", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cudnn": _find_cuda_lib(
- "cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
- cuda_config.cudnn_version),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config)
- }
+ """Returns the CUDA and cuDNN libraries on the system.
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
-def _find_cuda_include_path(repository_ctx, cuda_config):
- """Returns the path to the directory containing cuda.h
+ Returns:
+ Map of library names to structs of filename and path.
+ """
+ cpu_value = cuda_config.cpu_value
+ return {
+ "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
+ "cudart": _find_cuda_lib(
+ "cudart",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudart_static": _find_cuda_lib(
+ "cudart_static",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ static = True,
+ ),
+ "cublas": _find_cuda_lib(
+ "cublas",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cusolver": _find_cuda_lib(
+ "cusolver",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "curand": _find_cuda_lib(
+ "curand",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cufft": _find_cuda_lib(
+ "cufft",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudnn": _find_cuda_lib(
+ "cudnn",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cudnn_install_basedir,
+ cuda_config.cudnn_version,
+ ),
+ "cupti": _find_cupti_lib(repository_ctx, cuda_config),
+ }
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
+def _find_cuda_include_path(repository_ctx, cuda_config):
+ """Returns the path to the directory containing cuda.h
- Returns:
- The path of the directory containing the CUDA headers.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the CUDA headers.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
- """Returns the path to the directory containing cudnn.h
-
- Args:
- repository_ctx: The repository context.
- cudnn_install_basedir: The cudnn install directory as returned by
- _cudnn_install_basedir.
+ """Returns the path to the directory containing cudnn.h
- Returns:
- The path of the directory containing the cudnn header.
- """
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
- return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
- if repository_ctx.path("/usr/include/cudnn.h").exists:
- return "/usr/include"
- auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+ Args:
+ repository_ctx: The repository context.
+ cudnn_install_basedir: The cudnn install directory as returned by
+ _cudnn_install_basedir.
+ Returns:
+ The path of the directory containing the cudnn header.
+ """
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
+ return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
+ if repository_ctx.path("/usr/include/cudnn.h").exists:
+ return "/usr/include"
+ auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing libdevice in bitcode format.
+ """Returns the path to the directory containing libdevice in bitcode format.
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
-
- Returns:
- The path of the directory containing the CUDA headers.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in NVVM_LIBDEVICE_PATHS:
- if repository_ctx.path("%s/%slibdevice.10.bc" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find libdevice.10.bc under %s" % cuda_toolkit_path)
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the libdevice files.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for libdevice_file in NVVM_LIBDEVICE_FILES:
+ for relative_path in NVVM_LIBDEVICE_PATHS:
+ if repository_ctx.path("%s/%s%s" % (cuda_toolkit_path, relative_path, libdevice_file)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find libdevice*.bc files under %s" % cuda_toolkit_path)
def _cudart_static_linkopt(cpu_value):
- """Returns additional platform-specific linkopts for cudart."""
- return "" if cpu_value == "Darwin" else "\"-lrt\","
+ """Returns additional platform-specific linkopts for cudart."""
+ return "" if cpu_value == "Darwin" else "\"-lrt\","
def _get_cuda_config(repository_ctx):
- """Detects and returns information about the CUDA installation on the system.
-
- Args:
- repository_ctx: The repository context.
-
- Returns:
- A struct containing the following fields:
- cuda_toolkit_path: The CUDA toolkit installation directory.
- cudnn_install_basedir: The cuDNN installation directory.
- cuda_version: The version of CUDA on the system.
- cudnn_version: The version of cuDNN on the system.
- compute_capabilities: A list of the system's CUDA compute capabilities.
- cpu_value: The name of the host operating system.
- """
- cpu_value = get_cpu_value(repository_ctx)
- cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
- cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
- cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
- cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
- return struct(
- cuda_toolkit_path = cuda_toolkit_path,
- cudnn_install_basedir = cudnn_install_basedir,
- cuda_version = cuda_version,
- cudnn_version = cudnn_version,
- compute_capabilities = _compute_capabilities(repository_ctx),
- cpu_value = cpu_value)
-
-
-def _tpl(repository_ctx, tpl, substitutions={}, out=None):
- if not out:
- out = tpl.replace(":", "/")
- repository_ctx.template(
- out,
- Label("//third_party/gpus/%s.tpl" % tpl),
- substitutions)
-
+ """Detects and returns information about the CUDA installation on the system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A struct containing the following fields:
+ cuda_toolkit_path: The CUDA toolkit installation directory.
+ cudnn_install_basedir: The cuDNN installation directory.
+ cuda_version: The version of CUDA on the system.
+ cudnn_version: The version of cuDNN on the system.
+ compute_capabilities: A list of the system's CUDA compute capabilities.
+ cpu_value: The name of the host operating system.
+ """
+ cpu_value = get_cpu_value(repository_ctx)
+ cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
+ cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
+ cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
+ cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
+ return struct(
+ cuda_toolkit_path = cuda_toolkit_path,
+ cudnn_install_basedir = cudnn_install_basedir,
+ cuda_version = cuda_version,
+ cudnn_version = cudnn_version,
+ compute_capabilities = _compute_capabilities(repository_ctx),
+ cpu_value = cpu_value,
+ )
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions,
+ )
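
_tpl is a thin wrapper over repository_ctx.template: the label-style template name becomes the output path (crosstool:BUILD turns into crosstool/BUILD) and each %{...} placeholder is replaced literally. A self-contained sketch of that substitution, assuming plain string templates:

    def expand_template(template_text, substitutions):
        # Replace each key literally; keys conventionally look like "%{name}".
        for key, value in substitutions.items():
            template_text = template_text.replace(key, value)
        return template_text

    out = "cuda:BUILD".replace(":", "/")   # -> "cuda/BUILD"
    text = expand_template('version = "%{cuda_version}"',
                           {"%{cuda_version}": "9.0"})
    print(out, text)                       # cuda/BUILD version = "9.0"
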
def _file(repository_ctx, label):
- repository_ctx.template(
- label.replace(":", "/"),
- Label("//third_party/gpus/%s.tpl" % label),
- {})
-
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {},
+ )
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
@@ -792,379 +949,498 @@ def error_gpu_disabled():
)
"""
-
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
error_gpu_disabled()
"""
-
def _create_dummy_repository(repository_ctx):
- cpu_value = get_cpu_value(repository_ctx)
-
- # Set up BUILD file for cuda/.
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]",
- })
- _tpl(repository_ctx, "cuda:BUILD",
- {
- "%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
- "%{cudart_static_lib}": _lib_name("cudart_static", cpu_value,
- static=True),
- "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
- "%{cudart_lib}": _lib_name("cudart", cpu_value),
- "%{cublas_lib}": _lib_name("cublas", cpu_value),
- "%{cusolver_lib}": _lib_name("cusolver", cpu_value),
- "%{cudnn_lib}": _lib_name("cudnn", cpu_value),
- "%{cufft_lib}": _lib_name("cufft", cpu_value),
- "%{curand_lib}": _lib_name("curand", cpu_value),
- "%{cupti_lib}": _lib_name("cupti", cpu_value),
- "%{cuda_include_genrules}": '',
- "%{cuda_headers}": '',
- })
-
- # Create dummy files for the CUDA toolkit since they are still required by
- # tensorflow/core/platform/default/build_config:cuda.
- repository_ctx.file("cuda/cuda/include/cuda.h", "")
- repository_ctx.file("cuda/cuda/include/cublas.h", "")
- repository_ctx.file("cuda/cuda/include/cudnn.h", "")
- repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
- _tpl(repository_ctx, "cuda:cuda_config.h",
- {
- "%{cuda_version}": _DEFAULT_CUDA_VERSION,
- "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
- "%{cuda_compute_capabilities}": ",".join([
- "CudaVersion(\"%s\")" % c
- for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES]),
- "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
- }, "cuda/cuda/cuda_config.h")
-
- # If cuda_configure is not configured to build with GPU support, and the user
- # attempts to build with --config=cuda, add a dummy build rule to intercept
- # this and fail with an actionable error message.
- repository_ctx.file("crosstool/error_gpu_disabled.bzl",
- _DUMMY_CROSSTOOL_BZL_FILE)
- repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
-
-
-def _execute(repository_ctx, cmdline, error_msg=None, error_details=None,
- empty_stdout_fine=False):
- """Executes an arbitrary shell command.
-
- Args:
- repository_ctx: the repository_ctx object
- cmdline: list of strings, the command to execute
- error_msg: string, a summary of the error if the command fails
- error_details: string, details about the error or steps to fix it
- empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
- it's an error
- Return:
- the result of repository_ctx.execute(cmdline)
- """
- result = repository_ctx.execute(cmdline)
- if result.stderr or not (empty_stdout_fine or result.stdout):
- auto_configure_fail(
- "\n".join([
- error_msg.strip() if error_msg else "Repository command failed",
- result.stderr.strip(),
- error_details if error_details else ""]))
- return result
-
+ cpu_value = get_cpu_value(repository_ctx)
+
+ # Set up BUILD file for cuda/.
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "False",
+ "%{cuda_extra_copts}": "[]",
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
+ "%{cudart_static_lib}": _lib_name(
+ "cudart_static",
+ cpu_value,
+ static = True,
+ ),
+ "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
+ "%{cudart_lib}": _lib_name("cudart", cpu_value),
+ "%{cublas_lib}": _lib_name("cublas", cpu_value),
+ "%{cusolver_lib}": _lib_name("cusolver", cpu_value),
+ "%{cudnn_lib}": _lib_name("cudnn", cpu_value),
+ "%{cufft_lib}": _lib_name("cufft", cpu_value),
+ "%{curand_lib}": _lib_name("curand", cpu_value),
+ "%{cupti_lib}": _lib_name("cupti", cpu_value),
+ "%{cuda_include_genrules}": "",
+ "%{cuda_headers}": "",
+ },
+ )
+
+ # Create dummy files for the CUDA toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:cuda.
+ repository_ctx.file("cuda/cuda/include/cuda.h", "")
+ repository_ctx.file("cuda/cuda/include/cublas.h", "")
+ repository_ctx.file("cuda/cuda/include/cudnn.h", "")
+ repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": _DEFAULT_CUDA_VERSION,
+ "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
+ "%{cuda_compute_capabilities}": ",".join([
+ "CudaVersion(\"%s\")" % c
+ for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ ]),
+ "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
+
+ # If cuda_configure is not configured to build with GPU support, and the user
+ # attempts to build with --config=cuda, add a dummy build rule to intercept
+ # this and fail with an actionable error message.
+ repository_ctx.file(
+ "crosstool/error_gpu_disabled.bzl",
+ _DUMMY_CROSSTOOL_BZL_FILE,
+ )
+ repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
+
+def _execute(
+ repository_ctx,
+ cmdline,
+ error_msg = None,
+ error_details = None,
+ empty_stdout_fine = False):
+ """Executes an arbitrary shell command.
+
+ Args:
+ repository_ctx: the repository_ctx object
+ cmdline: list of strings, the command to execute
+ error_msg: string, a summary of the error if the command fails
+ error_details: string, details about the error or steps to fix it
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error
+ Returns:
+ the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]),
+ )
+ return result
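
The failure rule in _execute deserves spelling out: any stderr output is fatal, and so is an empty stdout unless the caller passes empty_stdout_fine. A stand-alone Python equivalent using subprocess (a sketch, not TensorFlow's API):

    import subprocess

    def execute(cmdline, error_msg=None, error_details=None,
                empty_stdout_fine=False):
        result = subprocess.run(cmdline, capture_output=True, text=True)
        if result.stderr or not (empty_stdout_fine or result.stdout):
            raise RuntimeError("\n".join([
                (error_msg or "Repository command failed").strip(),
                result.stderr.strip(),
                error_details or "",
            ]))
        return result
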
def _norm_path(path):
- """Returns a path with '/' and remove the trailing slash."""
- path = path.replace("\\", "/")
- if path[-1] == "/":
- path = path[:-1]
- return path
-
-
-def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
- src_files = [], dest_files = []):
- """Returns a genrule to symlink(or copy if on Windows) a set of files.
-
- If src_dir is passed, files will be read from the given directory; otherwise
- we assume files are in src_files and dest_files
- """
- if src_dir != None:
- src_dir = _norm_path(src_dir)
- dest_dir = _norm_path(dest_dir)
- files = '\n'.join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
- # Create a list with the src_dir stripped to use for outputs.
- dest_files = files.replace(src_dir, '').splitlines()
- src_files = files.splitlines()
- command = []
- if not _is_windows(repository_ctx):
- # We clear folders that might have been generated previously to avoid
- # undesired inclusions
- command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
- command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
- command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
- command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
- outs = []
- for i in range(len(dest_files)):
- if dest_files[i] != "":
- # If we have only one file to link we do not want to use the dest_dir, as
- # $(@D) will include the full path to the file.
- dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s'
- command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest))
- outs.append(' "' + dest_dir + dest_files[i] + '",')
- genrule = _genrule(src_dir, genrule_name, " && ".join(command),
- "\n".join(outs))
- return genrule
-
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+def symlink_genrule_for_dir(
+ repository_ctx,
+ src_dir,
+ dest_dir,
+ genrule_name,
+ src_files = [],
+ dest_files = []):
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
+
+ If src_dir is passed, files will be read from the given directory; otherwise
+ we assume files are in src_files and dest_files
+ """
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+
+ # Create a list with the src_dir stripped to use for outputs.
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+ if not _is_windows(repository_ctx):
+ # We clear folders that might have been generated previously to avoid
+ # undesired inclusions
+ command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
+ command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
+ command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
+ command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link, we do not want to use dest_dir,
+ # as $(@D) will already include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+ # On Windows, symlink is not supported, so we just copy all the files.
+ cmd = "cp -f" if _is_windows(repository_ctx) else "ln -s"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ src_dir,
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
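
Stripped of the directory-listing and cleanup details, the heart of symlink_genrule_for_dir is one link (or copy) command per file, joined with "&&" into a single genrule cmd. A minimal sketch, assuming src_files and dest_files are already paired:

    def build_symlink_cmd(src_files, dest_files, is_windows=False):
        # Windows has no usable symlink from genrules, so files are copied.
        tool = "cp -f" if is_windows else "ln -s"
        parts = []
        for src, dest in zip(src_files, dest_files):
            if dest:
                parts.append('%s "%s" "$(@D)/%s"' % (tool, src, dest))
        return " && ".join(parts)

    print(build_symlink_cmd(
        ["/usr/local/cuda/include/cuda.h"],
        ["cuda/include/cuda.h"],
    ))
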
def _genrule(src_dir, genrule_name, command, outs):
- """Returns a string with a genrule.
-
- Genrule executes the given command and produces the given outputs.
- """
- return (
- 'genrule(\n' +
- ' name = "' +
- genrule_name + '",\n' +
- ' outs = [\n' +
- outs +
- '\n ],\n' +
- ' cmd = """\n' +
- command +
- '\n """,\n' +
- ')\n'
- )
+ """Returns a string with a genrule.
+ Genrule executes the given command and produces the given outputs.
+ """
+ return (
+ "genrule(\n" +
+ ' name = "' +
+ genrule_name + '",\n' +
+ " outs = [\n" +
+ outs +
+ "\n ],\n" +
+ ' cmd = """\n' +
+ command +
+ '\n """,\n' +
+ ")\n"
+ )
def _read_dir(repository_ctx, src_dir):
- """Returns a string with all files in a directory.
-
- Finds all files inside a directory, traversing subfolders and following
- symlinks. The returned string contains the full path of all files
- separated by line breaks.
- """
- if _is_windows(repository_ctx):
- src_dir = src_dir.replace("/", "\\")
- find_result = _execute(
- repository_ctx, ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
- empty_stdout_fine=True)
- # src_files will be used in genrule.outs where the paths must
- # use forward slashes.
- result = find_result.stdout.replace("\\", "/")
- else:
- find_result = _execute(
- repository_ctx, ["find", src_dir, "-follow", "-type", "f"],
- empty_stdout_fine=True)
- result = find_result.stdout
- return result
+ """Returns a string with all files in a directory.
+
+ Finds all files inside a directory, traversing subfolders and following
+ symlinks. The returned string contains the full path of all files
+ separated by line breaks.
+ """
+ if _is_windows(repository_ctx):
+ src_dir = src_dir.replace("/", "\\")
+ find_result = _execute(
+ repository_ctx,
+ ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
+ empty_stdout_fine = True,
+ )
+
+ # src_files will be used in genrule.outs where the paths must
+ # use forward slashes.
+ result = find_result.stdout.replace("\\", "/")
+ else:
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine = True,
+ )
+ result = find_result.stdout
+ return result
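
On Windows the listing comes from "dir /b /s /a-d" (bare format, recursive, files only) with backslashes flipped so the paths are usable in genrule outs; on Unix it is "find -follow -type f". A portable Python sketch of the same contract, one absolute file path per line, following symlinks:

    import os

    def read_dir(src_dir):
        paths = []
        for root, _, files in os.walk(src_dir, followlinks=True):
            for name in files:
                paths.append(os.path.join(root, name).replace("\\", "/"))
        return "\n".join(paths)
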
def _flag_enabled(repository_ctx, flag_name):
- if flag_name in repository_ctx.os.environ:
- value = repository_ctx.os.environ[flag_name].strip()
- return value == "1"
- return False
+ if flag_name in repository_ctx.os.environ:
+ value = repository_ctx.os.environ[flag_name].strip()
+ return value == "1"
+ return False
def _use_cuda_clang(repository_ctx):
- return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
+ return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
- if _use_cuda_clang(repository_ctx):
- capability_flags = ["--cuda-gpu-arch=sm_" +
- cap.replace(".", "") for cap in compute_capabilities]
- else:
- # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
- capability_flags = []
- return str(capability_flags)
+ if _use_cuda_clang(repository_ctx):
+ capability_flags = ["--cuda-gpu-arch=sm_" +
+ cap.replace(".", "") for cap in compute_capabilities]
+ else:
+ # For nvcc, capabilities are handled by "crosstool_wrapper_driver_is_not_gcc".
+ capability_flags = []
+ return str(capability_flags)
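
For the clang CUDA path, each compute capability string maps directly to a --cuda-gpu-arch flag ("3.5" becomes sm_35); for nvcc the list stays empty because the compiler wrapper injects the arch flags itself. The mapping as a runnable sketch:

    def cuda_extra_copts(compute_capabilities, use_clang):
        if not use_clang:
            return []  # nvcc: handled by crosstool_wrapper_driver_is_not_gcc
        return ["--cuda-gpu-arch=sm_" + cap.replace(".", "")
                for cap in compute_capabilities]

    print(cuda_extra_copts(["3.5", "5.2"], use_clang=True))
    # ['--cuda-gpu-arch=sm_35', '--cuda-gpu-arch=sm_52']
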
def _create_local_cuda_repository(repository_ctx):
- """Creates the repository containing files set up to build with CUDA."""
- cuda_config = _get_cuda_config(repository_ctx)
-
- cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
- cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
- cuda_config.cudnn_install_basedir)
- cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
- nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
-
- # Set up symbolic links for the cuda toolkit by creating genrules to do
- # symlinking. We create one genrule for each directory we want to track under
- # cuda_toolkit_path
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- genrules = [symlink_genrule_for_dir(repository_ctx,
- cuda_include_path, "cuda/include", "cuda-include")]
- genrules.append(symlink_genrule_for_dir(repository_ctx,
- nvvm_libdevice_dir, "cuda/nvvm/libdevice", "cuda-nvvm"))
- genrules.append(symlink_genrule_for_dir(repository_ctx,
- cupti_header_dir, "cuda/extras/CUPTI/include", "cuda-extras"))
-
- cuda_libs = _find_libs(repository_ctx, cuda_config)
- cuda_lib_src = []
- cuda_lib_dest = []
- for lib in cuda_libs.values():
- cuda_lib_src.append(lib.path)
- cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
- cuda_lib_src, cuda_lib_dest))
-
- # Set up the symbolic links for cudnn if cndnn was not installed to
- # CUDA_TOOLKIT_PATH.
- included_files = _read_dir(repository_ctx, cuda_include_path).replace(
- cuda_include_path, '').splitlines()
- if '/cudnn.h' not in included_files:
- genrules.append(symlink_genrule_for_dir(repository_ctx, None,
- "cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
- ["cudnn.h"]))
- else:
- genrules.append(
- 'filegroup(\n' +
+ """Creates the repository containing files set up to build with CUDA."""
+ cuda_config = _get_cuda_config(repository_ctx)
+
+ cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cuda_config.cudnn_install_basedir,
+ )
+ cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
+ nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
+
+ # Set up symbolic links for the cuda toolkit by creating genrules to do
+ # symlinking. We create one genrule for each directory we want to track under
+ # cuda_toolkit_path
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ genrules = [symlink_genrule_for_dir(
+ repository_ctx,
+ cuda_include_path,
+ "cuda/include",
+ "cuda-include",
+ )]
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ nvvm_libdevice_dir,
+ "cuda/nvvm/libdevice",
+ "cuda-nvvm",
+ ))
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ cupti_header_dir,
+ "cuda/extras/CUPTI/include",
+ "cuda-extras",
+ ))
+
+ cuda_libs = _find_libs(repository_ctx, cuda_config)
+ cuda_lib_src = []
+ cuda_lib_dest = []
+ for lib in cuda_libs.values():
+ cuda_lib_src.append(lib.path)
+ cuda_lib_dest.append("cuda/lib/" + lib.file_name)
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "cuda-lib",
+ cuda_lib_src,
+ cuda_lib_dest,
+ ))
+
+ # Set up the symbolic links for cudnn if cudnn was not installed to
+ # CUDA_TOOLKIT_PATH.
+ included_files = _read_dir(repository_ctx, cuda_include_path).replace(
+ cuda_include_path,
+ "",
+ ).splitlines()
+ if "/cudnn.h" not in included_files:
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "cuda/include/",
+ "cudnn-include",
+ [cudnn_header_dir + "/cudnn.h"],
+ ["cudnn.h"],
+ ))
+ else:
+ genrules.append(
+ "filegroup(\n" +
' name = "cudnn-include",\n' +
- ' srcs = [],\n' +
- ')\n'
+ " srcs = [],\n" +
+ ")\n",
)
- # Set up BUILD file for cuda/
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx, cuda_config.compute_capabilities),
- })
- _tpl(repository_ctx, "cuda:BUILD",
- {
- "%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
- "%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
- "%{cudart_static_linkopt}": _cudart_static_linkopt(
- cuda_config.cpu_value),
- "%{cudart_lib}": cuda_libs["cudart"].file_name,
- "%{cublas_lib}": cuda_libs["cublas"].file_name,
- "%{cusolver_lib}": cuda_libs["cusolver"].file_name,
- "%{cudnn_lib}": cuda_libs["cudnn"].file_name,
- "%{cufft_lib}": cuda_libs["cufft"].file_name,
- "%{curand_lib}": cuda_libs["curand"].file_name,
- "%{cupti_lib}": cuda_libs["cupti"].file_name,
- "%{cuda_include_genrules}": "\n".join(genrules),
- "%{cuda_headers}": ('":cuda-include",\n' +
- ' ":cudnn-include",')
- })
-
- is_cuda_clang = _use_cuda_clang(repository_ctx)
-
- should_download_clang = is_cuda_clang and _flag_enabled(
- repository_ctx, _TF_DOWNLOAD_CLANG)
- if should_download_clang:
- download_clang(repository_ctx, "crosstool/extra_tools")
-
- # Set up crosstool/
- cc = find_cc(repository_ctx)
- cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
-
- host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
- cuda_defines = {}
- if is_cuda_clang:
- cuda_defines["%{host_compiler_path}"] = str(cc)
- cuda_defines["%{host_compiler_warnings}"] = """
+ # Set up BUILD file for cuda/
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "True",
+ "%{cuda_extra_copts}": _compute_cuda_extra_copts(
+ repository_ctx,
+ cuda_config.compute_capabilities,
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
+ "%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
+ "%{cudart_static_linkopt}": _cudart_static_linkopt(
+ cuda_config.cpu_value,
+ ),
+ "%{cudart_lib}": cuda_libs["cudart"].file_name,
+ "%{cublas_lib}": cuda_libs["cublas"].file_name,
+ "%{cusolver_lib}": cuda_libs["cusolver"].file_name,
+ "%{cudnn_lib}": cuda_libs["cudnn"].file_name,
+ "%{cufft_lib}": cuda_libs["cufft"].file_name,
+ "%{curand_lib}": cuda_libs["curand"].file_name,
+ "%{cupti_lib}": cuda_libs["cupti"].file_name,
+ "%{cuda_include_genrules}": "\n".join(genrules),
+ "%{cuda_headers}": ('":cuda-include",\n' +
+ ' ":cudnn-include",'),
+ },
+ "cuda/BUILD",
+ )
+
+ is_cuda_clang = _use_cuda_clang(repository_ctx)
+
+ should_download_clang = is_cuda_clang and _flag_enabled(
+ repository_ctx,
+ _TF_DOWNLOAD_CLANG,
+ )
+ if should_download_clang:
+ download_clang(repository_ctx, "crosstool/extra_tools")
+
+ # Set up crosstool/
+ cc = find_cc(repository_ctx)
+ cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
+
+ host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
+ cuda_defines = {}
+ if is_cuda_clang:
+ cuda_defines["%{host_compiler_path}"] = str(cc)
+ cuda_defines["%{host_compiler_warnings}"] = """
# Some parts of the codebase set -Werror and hit this warning, so
# switch it off for now.
flag: "-Wno-invalid-partial-specialization"
"""
- cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
- _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty"})
- repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
- else:
- cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
- cuda_defines["%{host_compiler_warnings}"] = ""
- # TODO(klimek): We currently need to inject "/" as builtin directory path
- # to disable bazel's dependency checks.
- # The problem is that:
- # - the python rules symlink the python headers into the bazel root
- # - the rules use 'includes' in the BUILD file to redirect includes of the
- # python headers through those paths
- # - bazel currently uses -isystem for include paths specified via 'includes'
- # - gcc follows symlinks when resolving files via -isystem paths, and puts
- # the resolved paths into the .d file, which makes the dependency check
- # fail for bazel
- # There are multiple possible ways to solve this:
- # 1. make bazel not use -isystem for paths specified via 'includes'
- # 2. cp the headers instead of symlinking them
- #
- # Once this is fixed, the right builtin directory path is:
- # (host_compiler_includes +
- # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
- # The cuda directory needs to be passed, as there is currently no rule
- # providing the cuda headers in the same way the python headers are
- # provided.
- cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
- nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_config.cuda_toolkit_path,
- ".exe" if cuda_config.cpu_value == "Windows" else "")))
- _tpl(repository_ctx, "crosstool:BUILD",
- {"%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc"})
- _tpl(repository_ctx,
- "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
- {
- "%{cpu_compiler}": str(cc),
- "%{cuda_version}": cuda_config.cuda_version,
- "%{nvcc_path}": nvcc_path,
- "%{gcc_host_compiler_path}": str(cc),
- "%{cuda_compute_capabilities}": ", ".join(
- ["\"%s\"" % c for c in cuda_config.compute_capabilities]),
- })
- _tpl(repository_ctx, "crosstool:CROSSTOOL", cuda_defines, out="crosstool/CROSSTOOL")
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
- _tpl(repository_ctx, "cuda:cuda_config.h",
- {
- "%{cuda_version}": cuda_config.cuda_version,
- "%{cudnn_version}": cuda_config.cudnn_version,
- "%{cuda_compute_capabilities}": ",".join(
- ["CudaVersion(\"%s\")" % c
- for c in cuda_config.compute_capabilities]),
- "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
- }, "cuda/cuda/cuda_config.h")
+ cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
+ _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"})
+ repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.bat", "")
+ else:
+ cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
+ cuda_defines["%{host_compiler_warnings}"] = ""
+
+ # TODO(klimek): We currently need to inject "/" as builtin directory path
+ # to disable bazel's dependency checks.
+ # The problem is that:
+ # - the python rules symlink the python headers into the bazel root
+ # - the rules use 'includes' in the BUILD file to redirect includes of the
+ # python headers through those paths
+ # - bazel currently uses -isystem for include paths specified via 'includes'
+ # - gcc follows symlinks when resolving files via -isystem paths, and puts
+ # the resolved paths into the .d file, which makes the dependency check
+ # fail for bazel
+ # There are multiple possible ways to solve this:
+ # 1. make bazel not use -isystem for paths specified via 'includes'
+ # 2. cp the headers instead of symlinking them
+ #
+ # Once this is fixed, the right builtin directory path is:
+ # (host_compiler_includes +
+ # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
+ # The cuda directory needs to be passed, as there is currently no rule
+ # providing the cuda headers in the same way the python headers are
+ # provided.
+ cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
+ nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if _is_windows(repository_ctx) else "",
+ )))
+ _tpl(
+ repository_ctx,
+ "crosstool:BUILD",
+ {
+ "%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc",
+ "%{win_linker_files}": ":windows_msvc_wrapper_files",
+ },
+ )
+ wrapper_defines = {
+ "%{cpu_compiler}": str(cc),
+ "%{cuda_version}": cuda_config.cuda_version,
+ "%{nvcc_path}": nvcc_path,
+ "%{gcc_host_compiler_path}": str(cc),
+ "%{cuda_compute_capabilities}": ", ".join(
+ ["\"%s\"" % c for c in cuda_config.compute_capabilities],
+ ),
+ "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
+ }
+ _tpl(
+ repository_ctx,
+ "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
+ wrapper_defines,
+ )
+ _tpl(
+ repository_ctx,
+ "crosstool:windows/msvc_wrapper_for_nvcc.py",
+ wrapper_defines,
+ )
+ _tpl(
+ repository_ctx,
+ "crosstool:windows/msvc_wrapper_for_nvcc.bat",
+ {
+ "%{python_binary}": _get_python_bin(repository_ctx),
+ },
+ )
+
+ _tpl(
+ repository_ctx,
+ "crosstool:CROSSTOOL",
+ cuda_defines + _get_win_cuda_defines(repository_ctx),
+ out = "crosstool/CROSSTOOL",
+ )
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": cuda_config.cuda_version,
+ "%{cudnn_version}": cuda_config.cudnn_version,
+ "%{cuda_compute_capabilities}": ",".join(
+ [
+ "CudaVersion(\"%s\")" % c
+ for c in cuda_config.compute_capabilities
+ ],
+ ),
+ "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
- """Creates pointers to a remotely configured repo set up to build with CUDA."""
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx, _compute_capabilities(repository_ctx)),
-
- })
- _tpl(repository_ctx, "cuda:remote.BUILD",
- {
- "%{remote_cuda_repo}": remote_config_repo,
- }, "cuda/BUILD")
- _tpl(repository_ctx, "crosstool:remote.BUILD", {
- "%{remote_cuda_repo}": remote_config_repo,
- }, "crosstool/BUILD")
+ """Creates pointers to a remotely configured repo set up to build with CUDA."""
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "True",
+ "%{cuda_extra_copts}": _compute_cuda_extra_copts(
+ repository_ctx,
+ _compute_capabilities(repository_ctx),
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:remote.BUILD",
+ {
+ "%{remote_cuda_repo}": remote_config_repo,
+ },
+ "cuda/BUILD",
+ )
+ _tpl(repository_ctx, "crosstool:remote.BUILD", {
+ "%{remote_cuda_repo}": remote_config_repo,
+ }, "crosstool/BUILD")
def _cuda_autoconf_impl(repository_ctx):
- """Implementation of the cuda_autoconf repository rule."""
- if not _enable_cuda(repository_ctx):
- _create_dummy_repository(repository_ctx)
- else:
- if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
- _create_remote_cuda_repository(repository_ctx,
- repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO])
+ """Implementation of the cuda_autoconf repository rule."""
+ if not _enable_cuda(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
+ _create_remote_cuda_repository(
+ repository_ctx,
+ repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
+ )
else:
- _create_local_cuda_repository(repository_ctx)
-
+ _create_local_cuda_repository(repository_ctx)
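
The rewritten implementation is a flat three-way dispatch: a dummy repository when CUDA is disabled, a remote pass-through when TF_CUDA_CONFIG_REPO is set, and a full local probe otherwise. A sketch of the control flow; TF_NEED_CUDA is an assumption about what _enable_cuda checks, since that helper is not shown in this hunk:

    def cuda_autoconf(environ):
        if environ.get("TF_NEED_CUDA") != "1":  # assumed _enable_cuda check
            return "dummy"   # intercept --config=cuda with a clear error
        if "TF_CUDA_CONFIG_REPO" in environ:
            return "remote"  # point at a pre-configured repository
        return "local"       # probe the local CUDA/cuDNN installation

    print(cuda_autoconf({"TF_NEED_CUDA": "1"}))  # -> local
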
cuda_configure = repository_rule(
implementation = _cuda_autoconf_impl,
@@ -1181,6 +1457,7 @@ cuda_configure = repository_rule(
_TF_CUDA_COMPUTE_CAPABILITIES,
_TF_CUDA_CONFIG_REPO,
"NVVMIR_LIBRARY_DIR",
+ _PYTHON_BIN_PATH,
],
)
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index 8f65853918..bf9f9ca9cf 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -8,13 +8,14 @@ exports_files(["LICENSE.TXT"])
load(
"@org_tensorflow//third_party/llvm:llvm.bzl",
- "LLVM_COPTS",
- "LLVM_DEFINES",
- "LLVM_LINKOPTS",
"cmake_var_string",
"expand_cmake_vars",
"gentbl",
"llvm_all_cmake_vars",
+ "llvm_copts",
+ "llvm_defines",
+ "llvm_linkopts",
+ "llvm_support_platform_specific_srcs_glob",
)
load(
"@org_tensorflow//third_party:common.bzl",
@@ -121,7 +122,7 @@ cc_library(
"include/llvm/Config/config.h",
"include/llvm/Config/llvm-config.h",
],
- defines = LLVM_DEFINES,
+ defines = llvm_defines,
includes = ["include"],
)
@@ -198,7 +199,8 @@ cc_binary(
"utils/TableGen/*.cpp",
"utils/TableGen/*.h",
]),
- linkopts = LLVM_LINKOPTS,
+ copts = llvm_copts,
+ linkopts = llvm_linkopts,
stamp = 0,
deps = [
":config",
@@ -214,7 +216,8 @@ cc_binary(
"utils/FileCheck/*.cpp",
"utils/FileCheck/*.h",
]),
- linkopts = LLVM_LINKOPTS,
+ copts = llvm_copts,
+ linkopts = llvm_linkopts,
stamp = 0,
deps = [":support"],
)
@@ -385,8 +388,7 @@ cc_library(
"include/llvm/Target/AArch64/AsmParser/*.inc",
"lib/Target/AArch64/AsmParser/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_desc",
":aarch64_info",
@@ -411,8 +413,7 @@ cc_library(
"include/llvm/Target/AArch64/InstPrinter/*.inc",
"lib/Target/AArch64/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_target_gen",
":aarch64_utils",
@@ -435,8 +436,7 @@ cc_library(
"include/llvm/Target/AArch64/*.inc",
"lib/Target/AArch64/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_asm_printer",
":aarch64_desc",
@@ -469,8 +469,7 @@ cc_library(
"include/llvm/Target/AArch64/MCTargetDesc/*.inc",
"lib/Target/AArch64/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_asm_printer",
":aarch64_info",
@@ -497,8 +496,7 @@ cc_library(
"include/llvm/Target/AArch64/Disassembler/*.inc",
"lib/Target/AArch64/Disassembler/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_desc",
":aarch64_info",
@@ -526,8 +524,7 @@ cc_library(
"lib/Target/AArch64/AArch64*.h",
"lib/Target/AArch64/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":code_gen",
":config",
@@ -550,8 +547,7 @@ cc_library(
"include/llvm/Target/AArch64/Utils/*.inc",
"lib/Target/AArch64/Utils/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AArch64"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_target_gen",
":config",
@@ -573,8 +569,7 @@ cc_library(
"include/llvm/Transforms/AggressiveInstCombine/*.def",
"include/llvm/Transforms/AggressiveInstCombine/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -599,8 +594,7 @@ cc_library(
"include/llvm/Analysis/*.def",
"include/llvm/Analysis/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -624,8 +618,7 @@ cc_library(
"include/llvm/Target/AMDGPU/MCTargetDesc/*.inc",
"lib/Target/AMDGPU/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_asm_printer",
":amdgpu_info",
@@ -650,8 +643,7 @@ cc_library(
"include/llvm/Target/AMDGPU/Disassembler/*.inc",
"lib/Target/AMDGPU/Disassembler/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_desc",
":amdgpu_info",
@@ -676,8 +668,7 @@ cc_library(
"include/llvm/Target/AMDGPU/TargetInfo/*.inc",
"lib/Target/AMDGPU/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_target_gen",
":config",
@@ -699,8 +690,7 @@ cc_library(
"include/llvm/Target/AMDGPU/Utils/*.inc",
"lib/Target/AMDGPU/Utils/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_target_gen",
":config",
@@ -723,8 +713,7 @@ cc_library(
"include/llvm/Target/AMDGPU/AsmParser/*.inc",
"lib/Target/AMDGPU/AsmParser/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_desc",
":amdgpu_info",
@@ -749,8 +738,7 @@ cc_library(
"include/llvm/Target/AMDGPU/InstPrinter/*.inc",
"lib/Target/AMDGPU/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_utils",
":config",
@@ -772,8 +760,7 @@ cc_library(
"include/llvm/Target/AMDGPU/*.inc",
"lib/Target/AMDGPU/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/AMDGPU"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_asm_printer",
":amdgpu_desc",
@@ -809,8 +796,7 @@ cc_library(
"include/llvm/Target/ARM/AsmParser/*.inc",
"lib/Target/ARM/AsmParser/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_desc",
":arm_info",
@@ -836,8 +822,7 @@ cc_library(
"lib/Target/ARM/*.h",
"lib/Target/ARM/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_info",
":arm_target_gen",
@@ -861,8 +846,7 @@ cc_library(
"include/llvm/Target/ARM/*.inc",
"lib/Target/ARM/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":analysis",
":arm_asm_printer",
@@ -898,8 +882,7 @@ cc_library(
"include/llvm/Target/ARM/MCTargetDesc/*.inc",
"lib/Target/ARM/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_asm_printer",
":arm_info",
@@ -927,8 +910,7 @@ cc_library(
"include/llvm/Target/ARM/Disassembler/*.inc",
"lib/Target/ARM/Disassembler/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_desc",
":arm_info",
@@ -953,8 +935,7 @@ cc_library(
"include/llvm/Target/ARM/TargetInfo/*.inc",
"lib/Target/ARM/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_target_gen",
":config",
@@ -977,8 +958,7 @@ cc_library(
"include/llvm/Target/ARM/Utils/*.inc",
"lib/Target/ARM/Utils/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/ARM"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_target_gen",
":config",
@@ -1000,8 +980,7 @@ cc_library(
"include/llvm/AsmParser/*.def",
"include/llvm/AsmParser/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1024,8 +1003,7 @@ cc_library(
"include/llvm/CodeGen/AsmPrinter/*.inc",
"lib/CodeGen/AsmPrinter/*.def",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":binary_format",
@@ -1056,8 +1034,7 @@ cc_library(
"include/llvm/BinaryFormat/ELFRelocs/*.def",
"include/llvm/BinaryFormat/WasmRelocs/*.def",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":support",
@@ -1078,8 +1055,7 @@ cc_library(
"include/llvm/Bitcode/Reader/*.inc",
"include/llvm/Bitcode/BitstreamReader.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1103,8 +1079,7 @@ cc_library(
"include/llvm/Bitcode/BitcodeWriterPass.h",
"include/llvm/Bitcode/BitstreamWriter.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1129,8 +1104,7 @@ cc_library(
"include/llvm/CodeGen/*.inc",
"include/llvm/CodeGen/**/*.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":bit_reader",
@@ -1168,8 +1142,7 @@ cc_library(
"include/llvm/*.h",
"include/llvm/Analysis/*.def",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":attributes_compat_gen",
":attributes_gen",
@@ -1194,8 +1167,7 @@ cc_library(
"include/llvm/DebugInfo/CodeView/*.def",
"include/llvm/DebugInfo/CodeView/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1217,8 +1189,7 @@ cc_library(
"include/llvm/DebugInfo/MSF/*.def",
"include/llvm/DebugInfo/MSF/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":support",
@@ -1238,8 +1209,7 @@ cc_library(
"include/llvm/Demangle/*.def",
"include/llvm/Demangle/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [":config"],
)
@@ -1256,8 +1226,7 @@ cc_library(
"include/llvm/ExecutionEngine/*.def",
"include/llvm/ExecutionEngine/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1282,8 +1251,7 @@ cc_library(
"include/llvm/CodeGen/GlobalISel/*.def",
"include/llvm/CodeGen/GlobalISel/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":code_gen",
@@ -1313,8 +1281,7 @@ cc_library(
"include/llvm/Transforms/InstrProfiling.h",
"include/llvm/Transforms/PGOInstrumentation.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1339,8 +1306,7 @@ cc_library(
"include/llvm/Transforms/InstCombine/*.def",
"include/llvm/Transforms/InstCombine/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1367,8 +1333,7 @@ cc_library(
"include/llvm/Transforms/IPO/*.def",
"include/llvm/Transforms/IPO/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":aggressive_inst_combine",
":analysis",
@@ -1402,8 +1367,7 @@ cc_library(
"include/llvm/IRReader/*.def",
"include/llvm/IRReader/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":asm_parser",
":bit_reader",
@@ -1426,8 +1390,7 @@ cc_library(
"include/llvm/Linker/*.def",
"include/llvm/Linker/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1449,8 +1412,7 @@ cc_library(
"include/llvm/MC/*.def",
"include/llvm/MC/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1472,8 +1434,7 @@ cc_library(
"include/llvm/MC/MCDisassembler/*.def",
"include/llvm/MC/MCDisassembler/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1494,8 +1455,7 @@ cc_library(
"include/llvm/MC/MCParser/*.def",
"include/llvm/MC/MCParser/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1516,8 +1476,7 @@ cc_library(
"include/llvm/Target/NVPTX/InstPrinter/*.inc",
"lib/Target/NVPTX/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":attributes_gen",
@@ -1541,8 +1500,7 @@ cc_library(
"include/llvm/Target/NVPTX/*.inc",
"lib/Target/NVPTX/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
":analysis",
":asm_printer",
@@ -1576,8 +1534,7 @@ cc_library(
"include/llvm/Target/NVPTX/MCTargetDesc/*.inc",
"lib/Target/NVPTX/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":config",
@@ -1603,8 +1560,7 @@ cc_library(
"lib/Target/NVPTX/NVPTX.h",
"lib/Target/NVPTX/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/NVPTX"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":attributes_gen",
@@ -1628,8 +1584,7 @@ cc_library(
"include/llvm/Object/*.def",
"include/llvm/Object/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":binary_format",
":bit_reader",
@@ -1655,8 +1610,7 @@ cc_library(
"include/llvm/Transforms/ObjCARC/*.def",
"include/llvm/Transforms/ObjCARC/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1679,8 +1633,7 @@ cc_library(
"include/llvm/ExecutionEngine/Orc/*.def",
"include/llvm/ExecutionEngine/Orc/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1707,8 +1660,7 @@ cc_library(
"include/llvm/Target/PowerPC/AsmParser/*.inc",
"lib/Target/PowerPC/AsmParser/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":config",
":mc",
@@ -1732,8 +1684,7 @@ cc_library(
"include/llvm/Target/PowerPC/InstPrinter/*.inc",
"lib/Target/PowerPC/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
@@ -1759,8 +1710,7 @@ cc_library(
"include/llvm/Target/PowerPC/*.inc",
"lib/Target/PowerPC/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":analysis",
":asm_printer",
@@ -1792,8 +1742,7 @@ cc_library(
"include/llvm/Target/PowerPC/MCTargetDesc/*.inc",
"lib/Target/PowerPC/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
@@ -1820,8 +1769,7 @@ cc_library(
"include/llvm/Target/PowerPC/Disassembler/*.inc",
"lib/Target/PowerPC/Disassembler/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":config",
":mc_disassembler",
@@ -1845,8 +1793,7 @@ cc_library(
"lib/Target/PowerPC/PPC*.h",
"lib/Target/PowerPC/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/PowerPC"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
@@ -1870,8 +1817,7 @@ cc_library(
"include/llvm/ProfileData/*.def",
"include/llvm/ProfileData/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1900,8 +1846,7 @@ cc_library(
"include/llvm/ExecutionEngine/RTDyldMemoryManager.h",
"include/llvm/ExecutionEngine/RuntimeDyld*.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1929,8 +1874,7 @@ cc_library(
"include/llvm/Transforms/IPO.h",
"include/llvm/Transforms/IPO/SCCP.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":aggressive_inst_combine",
":analysis",
@@ -1956,8 +1900,7 @@ cc_library(
"include/llvm/CodeGen/SelectionDAG/*.def",
"include/llvm/CodeGen/SelectionDAG/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":code_gen",
@@ -1976,14 +1919,12 @@ cc_library(
"lib/Support/*.c",
"lib/Support/*.cpp",
"lib/Support/*.inc",
- "lib/Support/Unix/*.inc",
- "lib/Support/Unix/*.h",
"include/llvm-c/*.h",
"include/llvm/CodeGen/MachineValueType.h",
"include/llvm/BinaryFormat/COFF.h",
"include/llvm/BinaryFormat/MachO.h",
"lib/Support/*.h",
- ]),
+ ] + llvm_support_platform_specific_srcs_glob),
hdrs = glob([
"include/llvm/Support/*.h",
"include/llvm/Support/*.def",
@@ -1995,8 +1936,7 @@ cc_library(
"include/llvm/BinaryFormat/MachO.def",
"include/llvm/Support/VCSRevision.h",
],
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":demangle",
@@ -2019,8 +1959,7 @@ cc_library(
"include/llvm/TableGen/*.inc",
"include/llvm/Target/*.def",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -2046,8 +1985,7 @@ cc_library(
"include/llvm/CodeGen/*.def",
"include/llvm/CodeGen/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2072,8 +2010,7 @@ cc_library(
"include/llvm/Transforms/Utils/*.def",
"include/llvm/Transforms/Utils/*.inc",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2097,8 +2034,7 @@ cc_library(
"include/llvm/Transforms/Vectorize/*.inc",
"include/llvm/Transforms/Vectorize.h",
]),
- copts = LLVM_COPTS,
- defines = LLVM_DEFINES,
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2122,8 +2058,7 @@ cc_library(
"include/llvm/Target/X86/AsmParser/*.inc",
"lib/Target/X86/AsmParser/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2148,8 +2083,7 @@ cc_library(
"include/llvm/Target/X86/InstPrinter/*.inc",
"lib/Target/X86/InstPrinter/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2173,8 +2107,7 @@ cc_library(
"include/llvm/Target/X86/*.inc",
"lib/Target/X86/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":analysis",
":asm_printer",
@@ -2207,8 +2140,7 @@ cc_library(
"include/llvm/Target/X86/MCTargetDesc/*.inc",
"lib/Target/X86/MCTargetDesc/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2233,8 +2165,7 @@ cc_library(
"include/llvm/Target/X86/Disassembler/*.inc",
"lib/Target/X86/Disassembler/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc_disassembler",
@@ -2257,8 +2188,7 @@ cc_library(
"include/llvm/Target/X86/TargetInfo/*.inc",
"lib/Target/X86/TargetInfo/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2280,8 +2210,7 @@ cc_library(
"include/llvm/Target/X86/Utils/*.inc",
"lib/Target/X86/Utils/*.h",
]),
- copts = LLVM_COPTS + ["-Iexternal/llvm/lib/Target/X86"],
- defines = LLVM_DEFINES,
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":code_gen",
":config",
diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index 2e809e5f14..dfdacafceb 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -226,9 +226,9 @@ llvm_all_cmake_vars = select({
})
-LLVM_LINKOPTS = ["-ldl", "-lm", "-lpthread"]
+llvm_linkopts = ["-ldl", "-lm", "-lpthread"]
-LLVM_DEFINES = [
+llvm_defines = [
"LLVM_ENABLE_STATS",
"__STDC_LIMIT_MACROS",
"__STDC_CONSTANT_MACROS",
@@ -237,4 +237,11 @@ LLVM_DEFINES = [
"LLVM_BUILD_GLOBAL_ISEL",
]
-LLVM_COPTS = []
+llvm_copts = []
+
+# Platform specific sources for libSupport.
+
+llvm_support_platform_specific_srcs_glob = [
+ "lib/Support/Unix/*.inc",
+ "lib/Support/Unix/*.h",
+]
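
The llvm.bzl change above, together with the llvm.BUILD hunks before it, replaces the per-rule "copts = LLVM_COPTS" / "defines = LLVM_DEFINES" pairs with a single lowercase "copts = llvm_copts" and hoists the Unix-only libSupport sources out of the inline glob into the named list llvm_support_platform_specific_srcs_glob. The point of the indirection is that a port to another platform only has to override one variable. A minimal sketch of what that override could look like follows; the Windows paths are an assumption, not part of this patch:

    # Hypothetical override (not in this patch): point libSupport at the
    # Windows implementation files instead of the Unix ones.
    llvm_support_platform_specific_srcs_glob = [
        "lib/Support/Windows/*.inc",
        "lib/Support/Windows/*.h",
    ]
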
diff --git a/third_party/mkl/LICENSE b/third_party/mkl/LICENSE
new file mode 100644
index 0000000000..9c8f3ea087
--- /dev/null
+++ b/third_party/mkl/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+      limitations under the License.
\ No newline at end of file
diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl
index 108d82e683..7ce2a7d9b0 100644
--- a/third_party/mkl_dnn/build_defs.bzl
+++ b/third_party/mkl_dnn/build_defs.bzl
@@ -9,5 +9,5 @@ def if_mkl_open_source_only(if_true, if_false = []):
"""
return select({
str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true,
- "//conditions:default": if_false
- })
\ No newline at end of file
+ "//conditions:default": if_false,
+ })
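
The build_defs.bzl hunk is purely cosmetic (a trailing comma and a final newline), but the macro it touches is worth a note: if_mkl_open_source_only wraps a select() so BUILD rules can flip flags when TensorFlow is built against the open-source MKL-DNN alone. A usage sketch under stated assumptions — the target, source file, and define are illustrative; only the load path and the macro come from the patch:

    # Illustrative use of the macro defined above.
    load("//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only")

    cc_library(
        name = "mkl_example",                # hypothetical target
        srcs = ["mkl_example.cc"],           # hypothetical source
        copts = if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]),
    )
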
diff --git a/third_party/nanopb.BUILD b/third_party/nanopb.BUILD
new file mode 100644
index 0000000000..d21866911b
--- /dev/null
+++ b/third_party/nanopb.BUILD
@@ -0,0 +1,23 @@
+# Description:
+# Nanopb, a tiny ANSI C protobuf implementation for use on embedded devices.
+
+licenses(["notice"]) # zlib license
+
+exports_files(["LICENSE.txt"])
+
+cc_library(
+ name = "nanopb",
+ srcs = [
+ "pb_common.c",
+ "pb_decode.c",
+ "pb_encode.c",
+ ],
+ hdrs = [
+ "pb.h",
+ "pb_common.h",
+ "pb_decode.h",
+ "pb_encode.h",
+ ],
+ includes = ["."],
+ visibility = ["//visibility:public"],
+)
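
nanopb.BUILD is a new overlay for an external nanopb checkout: it compiles the three core translation units and exposes them as a single public cc_library. A consumption sketch follows; the repository name @nanopb and the client target are assumptions, since the workspace rule that binds this file is not part of this hunk:

    # Assumed repository name; the client target is hypothetical.
    cc_library(
        name = "embedded_proto_io",
        srcs = ["embedded_proto_io.c"],
        deps = ["@nanopb//:nanopb"],   # the target defined in nanopb.BUILD
    )
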
diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD
index 341d58068b..89330eac54 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm.BUILD
@@ -8,45 +8,93 @@ exports_files(["LICENSE"])
cc_binary(
name = "nasm",
srcs = [
- "assemble.c",
- "assemble.h",
- "compiler.h",
- "crc64.c",
- "directiv.c",
- "directiv.h",
- "disp8.c",
- "disp8.h",
- "eval.c",
- "eval.h",
- "exprlib.c",
- "float.c",
- "float.h",
- "hashtbl.c",
- "hashtbl.h",
- "iflag.c",
- "iflag.h",
- "iflaggen.h",
- "ilog2.c",
- "insns.h",
- "insnsa.c",
- "insnsb.c",
- "insnsi.h",
- "labels.c",
- "labels.h",
- "lib/strlcpy.c",
- "listing.c",
- "listing.h",
- "macros.c",
- "md5.h",
- "md5c.c",
- "nasm.c",
- "nasm.h",
- "nasmlib.c",
- "nasmlib.h",
- "opflags.h",
+ "asm/assemble.c",
+ "asm/assemble.h",
+ "asm/directbl.c",
+ "asm/directiv.c",
+ "asm/directiv.h",
+ "asm/error.c",
+ "asm/eval.c",
+ "asm/eval.h",
+ "asm/exprdump.c",
+ "asm/exprlib.c",
+ "asm/float.c",
+ "asm/float.h",
+ "asm/labels.c",
+ "asm/listing.c",
+ "asm/listing.h",
+ "asm/nasm.c",
+ "asm/parser.c",
+ "asm/parser.h",
+ "asm/pptok.c",
+ "asm/pptok.h",
+ "asm/pragma.c",
+ "asm/preproc.c",
+ "asm/preproc.h",
+ "asm/preproc-nop.c",
+ "asm/quote.c",
+ "asm/quote.h",
+ "asm/rdstrnum.c",
+ "asm/segalloc.c",
+ "asm/stdscan.c",
+ "asm/stdscan.h",
+ "asm/strfunc.c",
+ "asm/tokens.h",
+ "asm/tokhash.c",
+ "common/common.c",
+ "config/unknown.h",
+ "disasm/disasm.c",
+ "disasm/disasm.h",
+ "disasm/sync.c",
+ "disasm/sync.h",
+ "include/compiler.h",
+ "include/disp8.h",
+ "include/error.h",
+ "include/hashtbl.h",
+ "include/iflag.h",
+ "include/insns.h",
+ "include/labels.h",
+ "include/md5.h",
+ "include/nasm.h",
+ "include/nasmint.h",
+ "include/nasmlib.h",
+ "include/opflags.h",
+ "include/perfhash.h",
+ "include/raa.h",
+ "include/rbtree.h",
+ "include/rdoff.h",
+ "include/saa.h",
+ "include/strlist.h",
+ "include/tables.h",
+ "include/ver.h",
+ "macros/macros.c",
+ "nasmlib/badenum.c",
+ "nasmlib/bsi.c",
+ "nasmlib/crc64.c",
+ "nasmlib/file.c",
+ "nasmlib/file.h",
+ "nasmlib/filename.c",
+ "nasmlib/hashtbl.c",
+ "nasmlib/ilog2.c",
+ "nasmlib/malloc.c",
+ "nasmlib/md5c.c",
+ "nasmlib/mmap.c",
+ "nasmlib/path.c",
+ "nasmlib/perfhash.c",
+ "nasmlib/raa.c",
+ "nasmlib/rbtree.c",
+ "nasmlib/readnum.c",
+ "nasmlib/realpath.c",
+ "nasmlib/saa.c",
+ "nasmlib/srcfile.c",
+ "nasmlib/string.c",
+ "nasmlib/strlist.c",
+ "nasmlib/ver.c",
+ "nasmlib/zerobuf.c",
"output/codeview.c",
"output/dwarf.h",
"output/elf.h",
+ "output/legacy.c",
"output/nulldbg.c",
"output/nullout.c",
"output/outaout.c",
@@ -56,9 +104,6 @@ cc_binary(
"output/outdbg.c",
"output/outelf.c",
"output/outelf.h",
- "output/outelf32.c",
- "output/outelf64.c",
- "output/outelfx32.c",
"output/outform.c",
"output/outform.h",
"output/outieee.c",
@@ -69,35 +114,31 @@ cc_binary(
"output/outrdf2.c",
"output/pecoff.h",
"output/stabs.h",
- "parser.c",
- "parser.h",
- "pptok.c",
- "pptok.h",
- "preproc.c",
- "preproc.h",
- "preproc-nop.c",
- "quote.c",
- "quote.h",
- "raa.c",
- "raa.h",
- "rbtree.c",
- "rbtree.h",
- "rdoff/rdoff.h",
- "realpath.c",
- "regflags.c",
- "regs.h",
- "regvals.c",
- "saa.c",
- "saa.h",
- "srcfile.c",
- "stdscan.c",
- "stdscan.h",
- "strfunc.c",
- "tables.h",
- "tokens.h",
- "tokhash.c",
- "ver.c",
+ "stdlib/snprintf.c",
+ "stdlib/strlcpy.c",
+ "stdlib/strnlen.c",
+ "stdlib/vsnprintf.c",
"version.h",
+ "x86/disp8.c",
+ "x86/iflag.c",
+ "x86/iflaggen.h",
+ "x86/insnsa.c",
+ "x86/insnsb.c",
+ "x86/insnsd.c",
+ "x86/insnsi.h",
+ "x86/insnsn.c",
+ "x86/regdis.c",
+ "x86/regdis.h",
+ "x86/regflags.c",
+ "x86/regs.c",
+ "x86/regs.h",
+ "x86/regvals.c",
+ ],
+ includes = [
+ "asm",
+ "include",
+ "output",
+ "x86",
],
copts = select({
":windows": [],
@@ -110,7 +151,10 @@ cc_binary(
defines = select({
":windows": [],
":windows_msvc": [],
- "//conditions:default": ["HAVE_SNPRINTF"],
+ "//conditions:default": [
+ "HAVE_SNPRINTF",
+ "HAVE_SYS_TYPES_H",
+ ],
}),
visibility = ["@jpeg//:__pkg__"],
)
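
The nasm.BUILD rewrite follows nasm's reorganized source tree — the formerly flat files now live under asm/, common/, disasm/, include/, macros/, nasmlib/, output/, stdlib/, and x86/ — so the srcs list is replaced wholesale, an includes list is added for the new header directories, and HAVE_SYS_TYPES_H joins HAVE_SNPRINTF on non-Windows builds. The defines select() keys on config_setting targets in the same package; a sketch of the typical shape of such a setting (the cpu value is an assumption — the real settings live in an unshown part of nasm.BUILD):

    # Assumed shape of the ":windows" setting referenced by the selects.
    config_setting(
        name = "windows",
        values = {"cpu": "x64_windows"},
    )
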
diff --git a/tools/bazel.rc b/tools/bazel.rc
index b3a9e6f0ef..913c4bc333 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -40,8 +40,6 @@ build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
-build:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true
-
build:mkl --define=using_mkl=true
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
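
With the build:win-cuda alias gone from tools/bazel.rc, Windows CUDA builds use the same flags as the generic build:cuda config. Anyone who still wants the old spelling can restore it locally — the line below is exactly the one removed above, and a user-level rc file is the assumed place to keep it:

    # Local restoration of the removed alias (in a personal rc file).
    build:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true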